summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py206
1 files changed, 71 insertions, 135 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..f81d2e2442 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
     > POSITION events 1
     > POSITION backfill 1
     > POSITION caches 1
@@ -53,17 +51,15 @@ import fcntl
 import logging
 import struct
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set
 
-from six import iteritems, iterkeys
+from six import iteritems
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
     SyncCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
+MYPY = False
+if MYPY:
+    from synapse.server import HomeServer
+
+
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.streamer = streamer
 
-        # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()  # type: Set[str]
-
-        # The streams the client is currently subscribing to.
-        self.connecting_streams = set()  # type:  Set[str]
-
-        # Map from stream name to list of updates to send once we've finished
-        # subscribing the client to the stream.
-        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
-
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)
@@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         )
 
     async def on_REPLICATE(self, cmd):
-        stream_name = cmd.stream_name
-        token = cmd.token
-
-        if stream_name == "ALL":
-            # Subscribe to all streams we're publishing to.
-            deferreds = [
-                run_in_background(self.subscribe_to_stream, stream, token)
-                for stream in iterkeys(self.streamer.streams_by_name)
-            ]
-
-            await make_deferred_yieldable(
-                defer.gatherResults(deferreds, consumeErrors=True)
-            )
-        else:
-            await self.subscribe_to_stream(stream_name, token)
+        # Subscribe to all streams we're publishing to.
+        for stream_name in self.streamer.streams_by_name:
+            current_token = self.streamer.get_stream_token(stream_name)
+            self.send_command(PositionCommand(stream_name, current_token))
 
     async def on_FEDERATION_ACK(self, cmd):
         self.streamer.federation_ack(cmd.token)
@@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
             cmd.last_seen,
         )
 
-    async def subscribe_to_stream(self, stream_name, token):
-        """Subscribe the remote to a stream.
-
-        This invloves checking if they've missed anything and sending those
-        updates down if they have. During that time new updates for the stream
-        are queued and sent once we've sent down any missed updates.
-        """
-        self.replication_streams.discard(stream_name)
-        self.connecting_streams.add(stream_name)
-
-        try:
-            # Get missing updates
-            updates, current_token = await self.streamer.get_stream_updates(
-                stream_name, token
-            )
-
-            # Send all the missing updates
-            for update in updates:
-                token, row = update[0], update[1]
-                self.send_command(RdataCommand(stream_name, token, row))
-
-            # We send a POSITION command to ensure that they have an up to
-            # date token (especially useful if we didn't send any updates
-            # above)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-            # Now we can send any updates that came in while we were subscribing
-            pending_rdata = self.pending_rdata.pop(stream_name, [])
-            updates = []
-            for token, update in pending_rdata:
-                # If the token is null, it is part of a batch update. Batches
-                # are multiple updates that share a single token. To denote
-                # this, the token is set to None for all tokens in the batch
-                # except for the last. If we find a None token, we keep looking
-                # through tokens until we find one that is not None and then
-                # process all previous updates in the batch as if they had the
-                # final token.
-                if token is None:
-                    # Store this update as part of a batch
-                    updates.append(update)
-                    continue
-
-                if token <= current_token:
-                    # This update or batch of updates is older than
-                    # current_token, dismiss it
-                    updates = []
-                    continue
-
-                updates.append(update)
-
-                # Send all updates that are part of this batch with the
-                # found token
-                for update in updates:
-                    self.send_command(RdataCommand(stream_name, token, update))
-
-                # Clear stored updates
-                updates = []
-
-            # They're now fully subscribed
-            self.replication_streams.add(stream_name)
-        except Exception as e:
-            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-            self.send_error("failed to handle replicate: %r", e)
-        finally:
-            self.connecting_streams.discard(stream_name)
-
     def stream_update(self, stream_name, token, data):
         """Called when a new update is available to stream to clients.
 
         We need to check if the client is interested in the stream or not
         """
-        if stream_name in self.replication_streams:
-            # The client is subscribed to the stream
-            self.send_command(RdataCommand(stream_name, token, data))
-        elif stream_name in self.connecting_streams:
-            # The client is being subscribed to the stream
-            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-            self.pending_rdata.setdefault(stream_name, []).append((token, data))
-        else:
-            # The client isn't subscribed
-            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+        self.send_command(RdataCommand(stream_name, token, data))
 
     def send_sync(self, data):
         self.send_command(SyncCommand(data))
@@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def __init__(
         self,
+        hs: "HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,
@@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.handler = handler
 
+        self.streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
         # Set of stream names that have been subscribe to, but haven't yet
         # caught up with. This is used to track when the client has been fully
         # connected to the remote.
-        self.streams_connecting = set()  # type: Set[str]
+        self.streams_connecting = set(STREAMS_MAP)  # type: Set[str]
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
-        self.pending_batches = {}  # type: Dict[str, Any]
+        self.pending_batches = {}  # type: Dict[str, List[Any]]
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
         BaseReplicationStreamProtocol.connectionMade(self)
 
         # Once we've connected subscribe to the necessary streams
-        for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
-            self.replicate(stream_name, token)
+        self.replicate()
 
         # Tell the server if we have any users currently syncing (should only
         # happen on synchrotrons)
@@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # We've now finished connecting to so inform the client handler
         self.handler.update_connection(self)
 
-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             )
             raise
 
-        if cmd.token is None:
+        if cmd.token is None or stream_name in self.streams_connecting:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
             self.pending_batches.setdefault(stream_name, []).append(row)
@@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             rows.append(row)
             await self.handler.on_rdata(stream_name, cmd.token, rows)
 
-    async def on_POSITION(self, cmd):
-        # When we get a `POSITION` command it means we've finished getting
-        # missing updates for the given stream, and are now up to date.
+    async def on_POSITION(self, cmd: PositionCommand):
+        stream = self.streams.get(cmd.stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            return
+
+        # Find where we previously streamed up to.
+        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+        if current_token is None:
+            logger.warning(
+                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+            )
+            return
+
+        # Fetch all updates between then and now.
+        limited = True
+        while limited:
+            updates, current_token, limited = await stream.get_updates_since(
+                current_token, cmd.token
+            )
+
+            # Check if the connection was closed underneath us, if so we bail
+            # rather than risk having concurrent catch ups going on.
+            if self.state == ConnectionStates.CLOSED:
+                return
+
+            if updates:
+                await self.handler.on_rdata(
+                    cmd.stream_name,
+                    current_token,
+                    [stream.parse_row(update[1]) for update in updates],
+                )
+
+        # We've now caught up to position sent to us, notify handler.
+        await self.handler.on_position(cmd.stream_name, cmd.token)
+
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        # Check if the connection was closed underneath us, if so we bail
+        # rather than risk having concurrent catch ups going on.
+        if self.state == ConnectionStates.CLOSED:
+            return
+
+        # Handle any RDATA that came in while we were catching up.
+        rows = self.pending_batches.pop(cmd.stream_name, [])
+        if rows:
+            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
 
     async def on_SYNC(self, cmd):
         self.handler.on_sync(cmd.data)
@@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
         self.handler.on_remote_server_up(cmd.data)
 
-    def replicate(self, stream_name, token):
+    def replicate(self):
         """Send the subscription request to the server
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r from %r",
-            self.id(),
-            stream_name,
-            token,
-        )
-
-        self.streams_connecting.add(stream_name)
+        logger.info("[%s] Subscribing to replication streams", self.id())
 
-        self.send_command(ReplicateCommand(stream_name, token))
+        self.send_command(ReplicateCommand())
 
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)