summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-03-03 17:05:35 +0000
committerErik Johnston <erik@matrix.org>2020-03-20 15:31:52 +0000
commit32c656865a9d1a7bcd998fb872f6d08e18f39be2 (patch)
tree04673998c1337ee498efce45ee193934c1b46751 /synapse/replication/tcp
parentRemove unused token param from REPLICATE cmd (diff)
downloadsynapse-32c656865a9d1a7bcd998fb872f6d08e18f39be2.tar.xz
Always subscribe to all streams.
This already happens since the worker merge.
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/protocol.py80
1 files changed, 12 insertions, 68 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 817b84ad7f..b371d66ce7 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -53,17 +53,15 @@ import fcntl
 import logging
 import struct
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set
 
-from six import iteritems, iterkeys
+from six import iteritems
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
@@ -412,13 +410,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.streamer = streamer
 
-        # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()  # type: Set[str]
-
-        # Map from stream name to list of updates to send once we've finished
-        # subscribing the client to the stream.
-        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
-
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)
@@ -434,20 +425,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         )
 
     async def on_REPLICATE(self, cmd):
-        stream_name = cmd.stream_name
-
-        if stream_name == "ALL":
-            # Subscribe to all streams we're publishing to.
-            deferreds = [
-                run_in_background(self.subscribe_to_stream, stream)
-                for stream in iterkeys(self.streamer.streams_by_name)
-            ]
-
-            await make_deferred_yieldable(
-                defer.gatherResults(deferreds, consumeErrors=True)
-            )
-        else:
-            await self.subscribe_to_stream(stream_name)
+        # Subscribe to all streams we're publishing to.
+        for stream_name in self.streamer.streams_by_name:
+            current_token = self.streamer.get_stream_token(stream_name)
+            self.send_command(PositionCommand(stream_name, current_token))
 
     async def on_FEDERATION_ACK(self, cmd):
         self.streamer.federation_ack(cmd.token)
@@ -471,37 +452,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
             cmd.last_seen,
         )
 
-    async def subscribe_to_stream(self, stream_name):
-        """Subscribe the remote to a stream.
-        """
-        self.replication_streams.discard(stream_name)
-
-        try:
-            # Get current stream position.
-            current_token = self.streamer.get_stream_token(stream_name)
-
-            # We send a POSITION command to ensure that they have an up to
-            # date token (especially useful if we didn't send any updates
-            # above)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-            # They're now fully subscribed
-            self.replication_streams.add(stream_name)
-        except Exception as e:
-            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-            self.send_error("failed to handle replicate: %r", e)
-
     def stream_update(self, stream_name, token, data):
         """Called when a new update is available to stream to clients.
 
         We need to check if the client is interested in the stream or not
         """
-        if stream_name in self.replication_streams:
-            # The client is subscribed to the stream
-            self.send_command(RdataCommand(stream_name, token, data))
-        else:
-            # The client isn't subscribed
-            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+        self.send_command(RdataCommand(stream_name, token, data))
 
     def send_sync(self, data):
         self.send_command(SyncCommand(data))
@@ -604,7 +560,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # Set of stream names that have been subscribe to, but haven't yet
         # caught up with. This is used to track when the client has been fully
         # connected to the remote.
-        self.streams_connecting = set()  # type: Set[str]
+        self.streams_connecting = set(STREAMS_MAP)  # type: Set[str]
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
@@ -615,8 +571,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         BaseReplicationStreamProtocol.connectionMade(self)
 
         # Once we've connected subscribe to the necessary streams
-        for stream_name in self.handler.get_streams_to_replicate():
-            self.replicate(stream_name)
+        self.replicate()
 
         # Tell the server if we have any users currently syncing (should only
         # happen on synchrotrons)
@@ -628,10 +583,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # We've now finished connecting to so inform the client handler
         self.handler.update_connection(self)
 
-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -706,19 +657,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
         self.handler.on_remote_server_up(cmd.data)
 
-    def replicate(self, stream_name):
+    def replicate(self):
         """Send the subscription request to the server
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r", self.id(), stream_name,
-        )
-
-        self.streams_connecting.add(stream_name)
+        logger.info("[%s] Subscribing to replication streams", self.id())
 
-        self.send_command(ReplicateCommand(stream_name))
+        self.send_command(ReplicateCommand("ALL"))
 
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)