summary refs log tree commit diff
path: root/synapse/replication/tcp/handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/handler.py')
-rw-r--r--synapse/replication/tcp/handler.py413
1 files changed, 279 insertions, 134 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..1c303f3a46 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -13,15 +13,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+from typing import (
+    Any,
+    Awaitable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 from prometheus_client import Counter
+from typing_extensions import Deque
 
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
 from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
@@ -43,8 +56,8 @@ from synapse.replication.tcp.streams import (
     EventsStream,
     FederationStream,
     Stream,
+    TypingStream,
 )
-from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
@@ -56,12 +69,16 @@ inbound_rdata_count = Counter(
 user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
 federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
 remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter(
-    "synapse_replication_tcp_resource_invalidate_cache", ""
-)
+
 user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 
+# the type of the entries in _command_queues_by_stream
+_StreamCommandQueue = Deque[
+    Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+]
+
+
 class ReplicationCommandHandler:
     """Handles incoming commands from replication as well as sending commands
     back out to connections.
@@ -97,6 +114,14 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, TypingStream):
+                # Only add TypingStream as a source on the instance in charge of
+                # typing.
+                if hs.config.worker.writers.typing == hs.get_instance_name():
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
@@ -108,12 +133,8 @@ class ReplicationCommandHandler:
 
             self._streams_to_replicate.append(stream)
 
-        self._position_linearizer = Linearizer(
-            "replication_position", clock=self._clock
-        )
-
-        # Map of stream to batched updates. See RdataCommand for info on how
-        # batching works.
+        # Map of stream name to batched updates. See RdataCommand for info on
+        # how batching works.
         self._pending_batches = {}  # type: Dict[str, List[Any]]
 
         # The factory used to create connections.
@@ -123,9 +144,6 @@ class ReplicationCommandHandler:
         # outgoing replication commands to.)
         self._connections = []  # type: List[AbstractConnection]
 
-        # For each connection, the incoming streams that are coming from that connection
-        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
-
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
             "",
@@ -133,6 +151,32 @@ class ReplicationCommandHandler:
             lambda: len(self._connections),
         )
 
+        # When POSITION or RDATA commands arrive, we stick them in a queue and process
+        # them in order in a separate background process.
+
+        # the streams which are currently being processed by _unsafe_process_queue
+        self._processing_streams = set()  # type: Set[str]
+
+        # for each stream, a queue of commands that are awaiting processing, and the
+        # connection that they arrived on.
+        self._command_queues_by_stream = {
+            stream_name: _StreamCommandQueue() for stream_name in self._streams
+        }
+
+        # For each connection, the incoming stream names that have received a POSITION
+        # from that connection.
+        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+
+        LaterGauge(
+            "synapse_replication_tcp_command_queue",
+            "Number of inbound RDATA/POSITION commands queued for processing",
+            ["stream_name"],
+            lambda: {
+                (stream_name,): len(queue)
+                for stream_name, queue in self._command_queues_by_stream.items()
+            },
+        )
+
         self._is_master = hs.config.worker_app is None
 
         self._federation_sender = None
@@ -143,15 +187,75 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
+    def _add_command_to_stream_queue(
+        self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+    ) -> None:
+        """Queue the given received command for processing
+
+        Adds the given command to the per-stream queue, and processes the queue if
+        necessary
+        """
+        stream_name = cmd.stream_name
+        queue = self._command_queues_by_stream.get(stream_name)
+        if queue is None:
+            logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
+            return
+
+        queue.append((cmd, conn))
+
+        # if we're already processing this stream, there's nothing more to do:
+        # the new entry on the queue will get picked up in due course
+        if stream_name in self._processing_streams:
+            return
+
+        # fire off a background process to start processing the queue.
+        run_as_background_process(
+            "process-replication-data", self._unsafe_process_queue, stream_name
+        )
+
+    async def _unsafe_process_queue(self, stream_name: str):
+        """Processes the command queue for the given stream, until it is empty
+
+        Does not check if there is already a thread processing the queue, hence "unsafe"
+        """
+        assert stream_name not in self._processing_streams
+
+        self._processing_streams.add(stream_name)
+        try:
+            queue = self._command_queues_by_stream.get(stream_name)
+            while queue:
+                cmd, conn = queue.popleft()
+                try:
+                    await self._process_command(cmd, conn, stream_name)
+                except Exception:
+                    logger.exception("Failed to handle command %s", cmd)
+        finally:
+            self._processing_streams.discard(stream_name)
+
+    async def _process_command(
+        self,
+        cmd: Union[PositionCommand, RdataCommand],
+        conn: AbstractConnection,
+        stream_name: str,
+    ) -> None:
+        if isinstance(cmd, PositionCommand):
+            await self._process_position(stream_name, conn, cmd)
+        elif isinstance(cmd, RdataCommand):
+            await self._process_rdata(stream_name, conn, cmd)
+        else:
+            # This shouldn't be possible
+            raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
+
     def start_replication(self, hs):
         """Helper method to start a replication connection to the remote server
         using TCP.
         """
         if hs.config.redis.redis_enabled:
+            import txredisapi
+
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
             )
-            import txredisapi
 
             logger.info(
                 "Connecting to redis (host=%r port=%r)",
@@ -198,7 +302,7 @@ class ReplicationCommandHandler:
         """
         return self._streams_to_replicate
 
-    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)
 
     def send_positions_to_connection(self, conn: AbstractConnection):
@@ -217,57 +321,73 @@ class ReplicationCommandHandler:
                 )
             )
 
-    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+    def on_USER_SYNC(
+        self, conn: AbstractConnection, cmd: UserSyncCommand
+    ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
         if self._is_master:
-            await self._presence_handler.update_external_syncs_row(
+            return self._presence_handler.update_external_syncs_row(
                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
             )
+        else:
+            return None
 
-    async def on_CLEAR_USER_SYNC(
+    def on_CLEAR_USER_SYNC(
         self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         if self._is_master:
-            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+            return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+        else:
+            return None
 
-    async def on_FEDERATION_ACK(
-        self, conn: AbstractConnection, cmd: FederationAckCommand
-    ):
+    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
         federation_ack_counter.inc()
 
         if self._federation_sender:
-            self._federation_sender.federation_ack(cmd.token)
+            self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
-    async def on_REMOVE_PUSHER(
+    def on_REMOVE_PUSHER(
         self, conn: AbstractConnection, cmd: RemovePusherCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         remove_pusher_counter.inc()
 
         if self._is_master:
-            await self._store.delete_pusher_by_app_id_pushkey_user_id(
-                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
-            )
+            return self._handle_remove_pusher(cmd)
+        else:
+            return None
+
+    async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
+        await self._store.delete_pusher_by_app_id_pushkey_user_id(
+            app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+        )
 
-            self._notifier.on_new_replication_data()
+        self._notifier.on_new_replication_data()
 
-    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+    def on_USER_IP(
+        self, conn: AbstractConnection, cmd: UserIpCommand
+    ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
         if self._is_master:
-            await self._store.insert_client_ip(
-                cmd.user_id,
-                cmd.access_token,
-                cmd.ip,
-                cmd.user_agent,
-                cmd.device_id,
-                cmd.last_seen,
-            )
+            return self._handle_user_ip(cmd)
+        else:
+            return None
+
+    async def _handle_user_ip(self, cmd: UserIpCommand):
+        await self._store.insert_client_ip(
+            cmd.user_id,
+            cmd.access_token,
+            cmd.ip,
+            cmd.user_agent,
+            cmd.device_id,
+            cmd.last_seen,
+        )
 
-        if self._server_notices_sender:
-            await self._server_notices_sender.on_user_ip(cmd.user_id)
+        assert self._server_notices_sender is not None
+        await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -275,42 +395,71 @@ class ReplicationCommandHandler:
         stream_name = cmd.stream_name
         inbound_rdata_count.labels(stream_name).inc()
 
-        try:
-            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
-        except Exception:
-            logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
-            raise
-
-        # We linearize here for two reasons:
+        # We put the received command into a queue here for two reasons:
         #   1. so we don't try and concurrently handle multiple rows for the
         #      same stream, and
         #   2. so we don't race with getting a POSITION command and fetching
         #      missing RDATA.
-        with await self._position_linearizer.queue(cmd.stream_name):
-            # make sure that we've processed a POSITION for this stream *on this
-            # connection*. (A POSITION on another connection is no good, as there
-            # is no guarantee that we have seen all the intermediate updates.)
-            sbc = self._streams_by_connection.get(conn)
-            if not sbc or stream_name not in sbc:
-                # Let's drop the row for now, on the assumption we'll receive a
-                # `POSITION` soon and we'll catch up correctly then.
-                logger.debug(
-                    "Discarding RDATA for unconnected stream %s -> %s",
-                    stream_name,
-                    cmd.token,
-                )
-                return
-
-            if cmd.token is None:
-                # I.e. this is part of a batch of updates for this stream (in
-                # which case batch until we get an update for the stream with a non
-                # None token).
-                self._pending_batches.setdefault(stream_name, []).append(row)
-            else:
-                # Check if this is the last of a batch of updates
-                rows = self._pending_batches.pop(stream_name, [])
-                rows.append(row)
-                await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+        self._add_command_to_stream_queue(conn, cmd)
+
+    async def _process_rdata(
+        self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+    ) -> None:
+        """Process an RDATA command
+
+        Called after the command has been popped off the queue of inbound commands
+        """
+        try:
+            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+        except Exception as e:
+            raise Exception(
+                "Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
+            ) from e
+
+        # make sure that we've processed a POSITION for this stream *on this
+        # connection*. (A POSITION on another connection is no good, as there
+        # is no guarantee that we have seen all the intermediate updates.)
+        sbc = self._streams_by_connection.get(conn)
+        if not sbc or stream_name not in sbc:
+            # Let's drop the row for now, on the assumption we'll receive a
+            # `POSITION` soon and we'll catch up correctly then.
+            logger.debug(
+                "Discarding RDATA for unconnected stream %s -> %s",
+                stream_name,
+                cmd.token,
+            )
+            return
+
+        if cmd.token is None:
+            # I.e. this is part of a batch of updates for this stream (in
+            # which case batch until we get an update for the stream with a non
+            # None token).
+            self._pending_batches.setdefault(stream_name, []).append(row)
+            return
+
+        # Check if this is the last of a batch of updates
+        rows = self._pending_batches.pop(stream_name, [])
+        rows.append(row)
+
+        stream = self._streams[stream_name]
+
+        # Find where we previously streamed up to.
+        current_token = stream.current_token(cmd.instance_name)
+
+        # Discard this data if this token is earlier than the current
+        # position. Note that streams can be reset (in which case you
+        # expect an earlier token), but that must be preceded by a
+        # POSITION command.
+        if cmd.token <= current_token:
+            logger.debug(
+                "Discarding RDATA from stream %s at position %s before previous position %s",
+                stream_name,
+                cmd.token,
+                current_token,
+            )
+        else:
+            await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -329,78 +478,74 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )
 
-    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
 
         logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
 
-        stream_name = cmd.stream_name
-        stream = self._streams.get(stream_name)
-        if not stream:
-            logger.error("Got POSITION for unknown stream: %s", stream_name)
-            return
+        self._add_command_to_stream_queue(conn, cmd)
 
-        # We protect catching up with a linearizer in case the replication
-        # connection reconnects under us.
-        with await self._position_linearizer.queue(stream_name):
-            # We're about to go and catch up with the stream, so remove from set
-            # of connected streams.
-            for streams in self._streams_by_connection.values():
-                streams.discard(stream_name)
-
-            # We clear the pending batches for the stream as the fetching of the
-            # missing updates below will fetch all rows in the batch.
-            self._pending_batches.pop(stream_name, [])
-
-            # Find where we previously streamed up to.
-            current_token = stream.current_token(cmd.instance_name)
-
-            # If the position token matches our current token then we're up to
-            # date and there's nothing to do. Otherwise, fetch all updates
-            # between then and now.
-            missing_updates = cmd.token != current_token
-            while missing_updates:
-                logger.info(
-                    "Fetching replication rows for '%s' between %i and %i",
-                    stream_name,
-                    current_token,
-                    cmd.token,
-                )
-                (
-                    updates,
-                    current_token,
-                    missing_updates,
-                ) = await stream.get_updates_since(
-                    cmd.instance_name, current_token, cmd.token
-                )
+    async def _process_position(
+        self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+    ) -> None:
+        """Process a POSITION command
 
-                # TODO: add some tests for this
+        Called after the command has been popped off the queue of inbound commands
+        """
+        stream = self._streams[stream_name]
 
-                # Some streams return multiple rows with the same stream IDs,
-                # which need to be processed in batches.
+        # We're about to go and catch up with the stream, so remove from set
+        # of connected streams.
+        for streams in self._streams_by_connection.values():
+            streams.discard(stream_name)
 
-                for token, rows in _batch_updates(updates):
-                    await self.on_rdata(
-                        stream_name,
-                        cmd.instance_name,
-                        token,
-                        [stream.parse_row(row) for row in rows],
-                    )
+        # We clear the pending batches for the stream as the fetching of the
+        # missing updates below will fetch all rows in the batch.
+        self._pending_batches.pop(stream_name, [])
 
-            logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+        # Find where we previously streamed up to.
+        current_token = stream.current_token(cmd.instance_name)
 
-            # We've now caught up to position sent to us, notify handler.
-            await self._replication_data_handler.on_position(
-                cmd.stream_name, cmd.instance_name, cmd.token
+        # If the position token matches our current token then we're up to
+        # date and there's nothing to do. Otherwise, fetch all updates
+        # between then and now.
+        missing_updates = cmd.token != current_token
+        while missing_updates:
+            logger.info(
+                "Fetching replication rows for '%s' between %i and %i",
+                stream_name,
+                current_token,
+                cmd.token,
+            )
+            (updates, current_token, missing_updates) = await stream.get_updates_since(
+                cmd.instance_name, current_token, cmd.token
             )
 
-            self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+            # TODO: add some tests for this
 
-    async def on_REMOTE_SERVER_UP(
-        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
-    ):
+            # Some streams return multiple rows with the same stream IDs,
+            # which need to be processed in batches.
+
+            for token, rows in _batch_updates(updates):
+                await self.on_rdata(
+                    stream_name,
+                    cmd.instance_name,
+                    token,
+                    [stream.parse_row(row) for row in rows],
+                )
+
+        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+        # We've now caught up to position sent to us, notify handler.
+        await self._replication_data_handler.on_position(
+            cmd.stream_name, cmd.instance_name, cmd.token
+        )
+
+        self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+
+    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
@@ -505,7 +650,7 @@ class ReplicationCommandHandler:
         """Ack data for the federation stream. This allows the master to drop
         data stored purely in memory.
         """
-        self.send_command(FederationAckCommand(token))
+        self.send_command(FederationAckCommand(self._instance_name, token))
 
     def send_user_sync(
         self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int