diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 30d8de48fa..f88e0a2e40 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -14,9 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
from prometheus_client import Counter
+from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
@@ -44,7 +56,6 @@ from synapse.replication.tcp.streams import (
Stream,
TypingStream,
)
-from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -62,6 +73,12 @@ invalidate_cache_counter = Counter(
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+# the type of the entries in _command_queues_by_stream
+_StreamCommandQueue = Deque[
+ Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+]
+
+
class ReplicationCommandHandler:
"""Handles incoming commands from replication as well as sending commands
back out to connections.
@@ -116,10 +133,6 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream)
- self._position_linearizer = Linearizer(
- "replication_position", clock=self._clock
- )
-
# Map of stream name to batched updates. See RdataCommand for info on
# how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
@@ -131,10 +144,6 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
- # For each connection, the incoming stream names that are coming from
- # that connection.
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
-
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
@@ -142,6 +151,32 @@ class ReplicationCommandHandler:
lambda: len(self._connections),
)
+ # When POSITION or RDATA commands arrive, we stick them in a queue and process
+ # them in order in a separate background process.
+
+ # the streams which are currently being processed by _unsafe_process_stream
+ self._processing_streams = set() # type: Set[str]
+
+ # for each stream, a queue of commands that are awaiting processing, and the
+ # connection that they arrived on.
+ self._command_queues_by_stream = {
+ stream_name: _StreamCommandQueue() for stream_name in self._streams
+ }
+
+ # For each connection, the incoming stream names that have received a POSITION
+ # from that connection.
+ self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+
+ LaterGauge(
+ "synapse_replication_tcp_command_queue",
+ "Number of inbound RDATA/POSITION commands queued for processing",
+ ["stream_name"],
+ lambda: {
+ (stream_name,): len(queue)
+ for stream_name, queue in self._command_queues_by_stream.items()
+ },
+ )
+
self._is_master = hs.config.worker_app is None
self._federation_sender = None
@@ -152,6 +187,64 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
+ async def _add_command_to_stream_queue(
+ self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ ) -> None:
+ """Queue the given received command for processing
+
+ Adds the given command to the per-stream queue, and processes the queue if
+ necessary
+ """
+ stream_name = cmd.stream_name
+ queue = self._command_queues_by_stream.get(stream_name)
+ if queue is None:
+ logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
+ return
+
+ # if we're already processing this stream, stick the new command in the
+ # queue, and we're done.
+ if stream_name in self._processing_streams:
+ queue.append((cmd, conn))
+ return
+
+ # otherwise, process the new command.
+
+ # arguably we should start off a new background process here, but nothing
+ # will be too upset if we don't return for ages, so let's save the overhead
+ # and use the existing logcontext.
+
+ self._processing_streams.add(stream_name)
+ try:
+ # might as well skip the queue for this one, since it must be empty
+ assert not queue
+ await self._process_command(cmd, conn, stream_name)
+
+ # now process any other commands that have built up while we were
+ # dealing with that one.
+ while queue:
+ cmd, conn = queue.popleft()
+ try:
+ await self._process_command(cmd, conn, stream_name)
+ except Exception:
+ logger.exception("Failed to handle command %s", cmd)
+
+ finally:
+ self._processing_streams.discard(stream_name)
+
+ async def _process_command(
+ self,
+ cmd: Union[PositionCommand, RdataCommand],
+ conn: AbstractConnection,
+ stream_name: str,
+ ) -> None:
+ if isinstance(cmd, PositionCommand):
+ await self._process_position(stream_name, conn, cmd)
+ elif isinstance(cmd, RdataCommand):
+ await self._process_rdata(stream_name, conn, cmd)
+ else:
+ # This shouldn't be possible
+ raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
+
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
@@ -285,63 +378,71 @@ class ReplicationCommandHandler:
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
- try:
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
- except Exception:
- logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
- raise
-
- # We linearize here for two reasons:
+ # We put the received command into a queue here for two reasons:
# 1. so we don't try and concurrently handle multiple rows for the
# same stream, and
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
- with await self._position_linearizer.queue(cmd.stream_name):
- # make sure that we've processed a POSITION for this stream *on this
- # connection*. (A POSITION on another connection is no good, as there
- # is no guarantee that we have seen all the intermediate updates.)
- sbc = self._streams_by_connection.get(conn)
- if not sbc or stream_name not in sbc:
- # Let's drop the row for now, on the assumption we'll receive a
- # `POSITION` soon and we'll catch up correctly then.
- logger.debug(
- "Discarding RDATA for unconnected stream %s -> %s",
- stream_name,
- cmd.token,
- )
- return
-
- if cmd.token is None:
- # I.e. this is part of a batch of updates for this stream (in
- # which case batch until we get an update for the stream with a non
- # None token).
- self._pending_batches.setdefault(stream_name, []).append(row)
- else:
- # Check if this is the last of a batch of updates
- rows = self._pending_batches.pop(stream_name, [])
- rows.append(row)
-
- stream = self._streams.get(stream_name)
- if not stream:
- logger.error("Got RDATA for unknown stream: %s", stream_name)
- return
-
- # Find where we previously streamed up to.
- current_token = stream.current_token(cmd.instance_name)
-
- # Discard this data if this token is earlier than the current
- # position. Note that streams can be reset (in which case you
- # expect an earlier token), but that must be preceded by a
- # POSITION command.
- if cmd.token <= current_token:
- logger.debug(
- "Discarding RDATA from stream %s at position %s before previous position %s",
- stream_name,
- cmd.token,
- current_token,
- )
- else:
- await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+ await self._add_command_to_stream_queue(conn, cmd)
+
+ async def _process_rdata(
+ self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ ) -> None:
+ """Process an RDATA command
+
+ Called after the command has been popped off the queue of inbound commands
+ """
+ try:
+ row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+ except Exception as e:
+ raise Exception(
+ "Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
+ ) from e
+
+ # make sure that we've processed a POSITION for this stream *on this
+ # connection*. (A POSITION on another connection is no good, as there
+ # is no guarantee that we have seen all the intermediate updates.)
+ sbc = self._streams_by_connection.get(conn)
+ if not sbc or stream_name not in sbc:
+ # Let's drop the row for now, on the assumption we'll receive a
+ # `POSITION` soon and we'll catch up correctly then.
+ logger.debug(
+ "Discarding RDATA for unconnected stream %s -> %s",
+ stream_name,
+ cmd.token,
+ )
+ return
+
+ if cmd.token is None:
+ # I.e. this is part of a batch of updates for this stream (in
+ # which case batch until we get an update for the stream with a non
+ # None token).
+ self._pending_batches.setdefault(stream_name, []).append(row)
+ return
+
+ # Check if this is the last of a batch of updates
+ rows = self._pending_batches.pop(stream_name, [])
+ rows.append(row)
+
+ stream = self._streams[stream_name]
+
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
+
+ # Discard this data if this token is earlier than the current
+ # position. Note that streams can be reset (in which case you
+ # expect an earlier token), but that must be preceded by a
+ # POSITION command.
+ if cmd.token <= current_token:
+ logger.debug(
+ "Discarding RDATA from stream %s at position %s before previous position %s",
+ stream_name,
+ cmd.token,
+ current_token,
+ )
+ else:
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -367,67 +468,65 @@ class ReplicationCommandHandler:
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
- stream_name = cmd.stream_name
- stream = self._streams.get(stream_name)
- if not stream:
- logger.error("Got POSITION for unknown stream: %s", stream_name)
- return
+ await self._add_command_to_stream_queue(conn, cmd)
- # We protect catching up with a linearizer in case the replication
- # connection reconnects under us.
- with await self._position_linearizer.queue(stream_name):
- # We're about to go and catch up with the stream, so remove from set
- # of connected streams.
- for streams in self._streams_by_connection.values():
- streams.discard(stream_name)
-
- # We clear the pending batches for the stream as the fetching of the
- # missing updates below will fetch all rows in the batch.
- self._pending_batches.pop(stream_name, [])
-
- # Find where we previously streamed up to.
- current_token = stream.current_token(cmd.instance_name)
-
- # If the position token matches our current token then we're up to
- # date and there's nothing to do. Otherwise, fetch all updates
- # between then and now.
- missing_updates = cmd.token != current_token
- while missing_updates:
- logger.info(
- "Fetching replication rows for '%s' between %i and %i",
- stream_name,
- current_token,
- cmd.token,
- )
- (
- updates,
- current_token,
- missing_updates,
- ) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
- )
+ async def _process_position(
+ self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ ) -> None:
+ """Process a POSITION command
- # TODO: add some tests for this
+ Called after the command has been popped off the queue of inbound commands
+ """
+ stream = self._streams[stream_name]
- # Some streams return multiple rows with the same stream IDs,
- # which need to be processed in batches.
+ # We're about to go and catch up with the stream, so remove from set
+ # of connected streams.
+ for streams in self._streams_by_connection.values():
+ streams.discard(stream_name)
- for token, rows in _batch_updates(updates):
- await self.on_rdata(
- stream_name,
- cmd.instance_name,
- token,
- [stream.parse_row(row) for row in rows],
- )
+ # We clear the pending batches for the stream as the fetching of the
+ # missing updates below will fetch all rows in the batch.
+ self._pending_batches.pop(stream_name, [])
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
- # We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ # If the position token matches our current token then we're up to
+ # date and there's nothing to do. Otherwise, fetch all updates
+ # between then and now.
+ missing_updates = cmd.token != current_token
+ while missing_updates:
+ logger.info(
+ "Fetching replication rows for '%s' between %i and %i",
+ stream_name,
+ current_token,
+ cmd.token,
+ )
+ (updates, current_token, missing_updates) = await stream.get_updates_since(
+ cmd.instance_name, current_token, cmd.token
)
- self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+ # TODO: add some tests for this
+
+ # Some streams return multiple rows with the same stream IDs,
+ # which need to be processed in batches.
+
+ for token, rows in _batch_updates(updates):
+ await self.on_rdata(
+ stream_name,
+ cmd.instance_name,
+ token,
+ [stream.parse_row(row) for row in rows],
+ )
+
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+ # We've now caught up to position sent to us, notify handler.
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
+
+ self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
|