summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/replication/tcp/handler.py331
1 files changed, 215 insertions, 116 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 30d8de48fa..f88e0a2e40 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -14,9 +14,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 from prometheus_client import Counter
+from typing_extensions import Deque
 
 from twisted.internet.protocol import ReconnectingClientFactory
 
@@ -44,7 +56,6 @@ from synapse.replication.tcp.streams import (
     Stream,
     TypingStream,
 )
-from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
@@ -62,6 +73,12 @@ invalidate_cache_counter = Counter(
 user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 
+# the type of the entries in _command_queues_by_stream
+_StreamCommandQueue = Deque[
+    Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+]
+
+
 class ReplicationCommandHandler:
     """Handles incoming commands from replication as well as sending commands
     back out to connections.
@@ -116,10 +133,6 @@ class ReplicationCommandHandler:
 
             self._streams_to_replicate.append(stream)
 
-        self._position_linearizer = Linearizer(
-            "replication_position", clock=self._clock
-        )
-
         # Map of stream name to batched updates. See RdataCommand for info on
         # how batching works.
         self._pending_batches = {}  # type: Dict[str, List[Any]]
@@ -131,10 +144,6 @@ class ReplicationCommandHandler:
         # outgoing replication commands to.)
         self._connections = []  # type: List[AbstractConnection]
 
-        # For each connection, the incoming stream names that are coming from
-        # that connection.
-        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
-
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
             "",
@@ -142,6 +151,32 @@ class ReplicationCommandHandler:
             lambda: len(self._connections),
         )
 
+        # When POSITION or RDATA commands arrive, we stick them in a queue and process
+        # them in order in a separate background process.
+
+        # the streams which are currently being processed by _unsafe_process_stream
+        self._processing_streams = set()  # type: Set[str]
+
+        # for each stream, a queue of commands that are awaiting processing, and the
+        # connection that they arrived on.
+        self._command_queues_by_stream = {
+            stream_name: _StreamCommandQueue() for stream_name in self._streams
+        }
+
+        # For each connection, the incoming stream names that have received a POSITION
+        # from that connection.
+        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+
+        LaterGauge(
+            "synapse_replication_tcp_command_queue",
+            "Number of inbound RDATA/POSITION commands queued for processing",
+            ["stream_name"],
+            lambda: {
+                (stream_name,): len(queue)
+                for stream_name, queue in self._command_queues_by_stream.items()
+            },
+        )
+
         self._is_master = hs.config.worker_app is None
 
         self._federation_sender = None
@@ -152,6 +187,64 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
+    async def _add_command_to_stream_queue(
+        self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+    ) -> None:
+        """Queue the given received command for processing
+
+        Adds the given command to the per-stream queue, and processes the queue if
+        necessary
+        """
+        stream_name = cmd.stream_name
+        queue = self._command_queues_by_stream.get(stream_name)
+        if queue is None:
+            logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
+            return
+
+        # if we're already processing this stream, stick the new command in the
+        # queue, and we're done.
+        if stream_name in self._processing_streams:
+            queue.append((cmd, conn))
+            return
+
+        # otherwise, process the new command.
+
+        # arguably we should start off a new background process here, but nothing
+        # will be too upset if we don't return for ages, so let's save the overhead
+        # and use the existing logcontext.
+
+        self._processing_streams.add(stream_name)
+        try:
+            # might as well skip the queue for this one, since it must be empty
+            assert not queue
+            await self._process_command(cmd, conn, stream_name)
+
+            # now process any other commands that have built up while we were
+            # dealing with that one.
+            while queue:
+                cmd, conn = queue.popleft()
+                try:
+                    await self._process_command(cmd, conn, stream_name)
+                except Exception:
+                    logger.exception("Failed to handle command %s", cmd)
+
+        finally:
+            self._processing_streams.discard(stream_name)
+
+    async def _process_command(
+        self,
+        cmd: Union[PositionCommand, RdataCommand],
+        conn: AbstractConnection,
+        stream_name: str,
+    ) -> None:
+        if isinstance(cmd, PositionCommand):
+            await self._process_position(stream_name, conn, cmd)
+        elif isinstance(cmd, RdataCommand):
+            await self._process_rdata(stream_name, conn, cmd)
+        else:
+            # This shouldn't be possible
+            raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
+
     def start_replication(self, hs):
         """Helper method to start a replication connection to the remote server
         using TCP.
@@ -285,63 +378,71 @@ class ReplicationCommandHandler:
         stream_name = cmd.stream_name
         inbound_rdata_count.labels(stream_name).inc()
 
-        try:
-            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
-        except Exception:
-            logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
-            raise
-
-        # We linearize here for two reasons:
+        # We put the received command into a queue here for two reasons:
         #   1. so we don't try and concurrently handle multiple rows for the
         #      same stream, and
         #   2. so we don't race with getting a POSITION command and fetching
         #      missing RDATA.
-        with await self._position_linearizer.queue(cmd.stream_name):
-            # make sure that we've processed a POSITION for this stream *on this
-            # connection*. (A POSITION on another connection is no good, as there
-            # is no guarantee that we have seen all the intermediate updates.)
-            sbc = self._streams_by_connection.get(conn)
-            if not sbc or stream_name not in sbc:
-                # Let's drop the row for now, on the assumption we'll receive a
-                # `POSITION` soon and we'll catch up correctly then.
-                logger.debug(
-                    "Discarding RDATA for unconnected stream %s -> %s",
-                    stream_name,
-                    cmd.token,
-                )
-                return
-
-            if cmd.token is None:
-                # I.e. this is part of a batch of updates for this stream (in
-                # which case batch until we get an update for the stream with a non
-                # None token).
-                self._pending_batches.setdefault(stream_name, []).append(row)
-            else:
-                # Check if this is the last of a batch of updates
-                rows = self._pending_batches.pop(stream_name, [])
-                rows.append(row)
-
-                stream = self._streams.get(stream_name)
-                if not stream:
-                    logger.error("Got RDATA for unknown stream: %s", stream_name)
-                    return
-
-                # Find where we previously streamed up to.
-                current_token = stream.current_token(cmd.instance_name)
-
-                # Discard this data if this token is earlier than the current
-                # position. Note that streams can be reset (in which case you
-                # expect an earlier token), but that must be preceded by a
-                # POSITION command.
-                if cmd.token <= current_token:
-                    logger.debug(
-                        "Discarding RDATA from stream %s at position %s before previous position %s",
-                        stream_name,
-                        cmd.token,
-                        current_token,
-                    )
-                else:
-                    await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+        await self._add_command_to_stream_queue(conn, cmd)
+
+    async def _process_rdata(
+        self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+    ) -> None:
+        """Process an RDATA command
+
+        Called after the command has been popped off the queue of inbound commands
+        """
+        try:
+            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+        except Exception as e:
+            raise Exception(
+                "Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
+            ) from e
+
+        # make sure that we've processed a POSITION for this stream *on this
+        # connection*. (A POSITION on another connection is no good, as there
+        # is no guarantee that we have seen all the intermediate updates.)
+        sbc = self._streams_by_connection.get(conn)
+        if not sbc or stream_name not in sbc:
+            # Let's drop the row for now, on the assumption we'll receive a
+            # `POSITION` soon and we'll catch up correctly then.
+            logger.debug(
+                "Discarding RDATA for unconnected stream %s -> %s",
+                stream_name,
+                cmd.token,
+            )
+            return
+
+        if cmd.token is None:
+            # I.e. this is part of a batch of updates for this stream (in
+            # which case batch until we get an update for the stream with a non
+            # None token).
+            self._pending_batches.setdefault(stream_name, []).append(row)
+            return
+
+        # Check if this is the last of a batch of updates
+        rows = self._pending_batches.pop(stream_name, [])
+        rows.append(row)
+
+        stream = self._streams[stream_name]
+
+        # Find where we previously streamed up to.
+        current_token = stream.current_token(cmd.instance_name)
+
+        # Discard this data if this token is earlier than the current
+        # position. Note that streams can be reset (in which case you
+        # expect an earlier token), but that must be preceded by a
+        # POSITION command.
+        if cmd.token <= current_token:
+            logger.debug(
+                "Discarding RDATA from stream %s at position %s before previous position %s",
+                stream_name,
+                cmd.token,
+                current_token,
+            )
+        else:
+            await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -367,67 +468,65 @@ class ReplicationCommandHandler:
 
         logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
 
-        stream_name = cmd.stream_name
-        stream = self._streams.get(stream_name)
-        if not stream:
-            logger.error("Got POSITION for unknown stream: %s", stream_name)
-            return
+        await self._add_command_to_stream_queue(conn, cmd)
 
-        # We protect catching up with a linearizer in case the replication
-        # connection reconnects under us.
-        with await self._position_linearizer.queue(stream_name):
-            # We're about to go and catch up with the stream, so remove from set
-            # of connected streams.
-            for streams in self._streams_by_connection.values():
-                streams.discard(stream_name)
-
-            # We clear the pending batches for the stream as the fetching of the
-            # missing updates below will fetch all rows in the batch.
-            self._pending_batches.pop(stream_name, [])
-
-            # Find where we previously streamed up to.
-            current_token = stream.current_token(cmd.instance_name)
-
-            # If the position token matches our current token then we're up to
-            # date and there's nothing to do. Otherwise, fetch all updates
-            # between then and now.
-            missing_updates = cmd.token != current_token
-            while missing_updates:
-                logger.info(
-                    "Fetching replication rows for '%s' between %i and %i",
-                    stream_name,
-                    current_token,
-                    cmd.token,
-                )
-                (
-                    updates,
-                    current_token,
-                    missing_updates,
-                ) = await stream.get_updates_since(
-                    cmd.instance_name, current_token, cmd.token
-                )
+    async def _process_position(
+        self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+    ) -> None:
+        """Process a POSITION command
 
-                # TODO: add some tests for this
+        Called after the command has been popped off the queue of inbound commands
+        """
+        stream = self._streams[stream_name]
 
-                # Some streams return multiple rows with the same stream IDs,
-                # which need to be processed in batches.
+        # We're about to go and catch up with the stream, so remove from set
+        # of connected streams.
+        for streams in self._streams_by_connection.values():
+            streams.discard(stream_name)
 
-                for token, rows in _batch_updates(updates):
-                    await self.on_rdata(
-                        stream_name,
-                        cmd.instance_name,
-                        token,
-                        [stream.parse_row(row) for row in rows],
-                    )
+        # We clear the pending batches for the stream as the fetching of the
+        # missing updates below will fetch all rows in the batch.
+        self._pending_batches.pop(stream_name, [])
 
-            logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+        # Find where we previously streamed up to.
+        current_token = stream.current_token(cmd.instance_name)
 
-            # We've now caught up to position sent to us, notify handler.
-            await self._replication_data_handler.on_position(
-                cmd.stream_name, cmd.instance_name, cmd.token
+        # If the position token matches our current token then we're up to
+        # date and there's nothing to do. Otherwise, fetch all updates
+        # between then and now.
+        missing_updates = cmd.token != current_token
+        while missing_updates:
+            logger.info(
+                "Fetching replication rows for '%s' between %i and %i",
+                stream_name,
+                current_token,
+                cmd.token,
+            )
+            (updates, current_token, missing_updates) = await stream.get_updates_since(
+                cmd.instance_name, current_token, cmd.token
             )
 
-            self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+            # TODO: add some tests for this
+
+            # Some streams return multiple rows with the same stream IDs,
+            # which need to be processed in batches.
+
+            for token, rows in _batch_updates(updates):
+                await self.on_rdata(
+                    stream_name,
+                    cmd.instance_name,
+                    token,
+                    [stream.parse_row(row) for row in rows],
+                )
+
+        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+        # We've now caught up to position sent to us, notify handler.
+        await self._replication_data_handler.on_position(
+            cmd.stream_name, cmd.instance_name, cmd.token
+        )
+
+        self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
     async def on_REMOTE_SERVER_UP(
         self, conn: AbstractConnection, cmd: RemoteServerUpCommand