summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-07-27 18:54:43 +0100
committerGitHub <noreply@github.com>2020-07-27 18:54:43 +0100
commitf57b99af22de874b11f44ef32c1f1425ec1344b9 (patch)
treed11bbffa9546fd70152a4a75dfad0a0c221f8bbe /synapse/replication/tcp
parentupdate changelog (diff)
downloadsynapse-f57b99af22de874b11f44ef32c1f1425ec1344b9.tar.xz
Handle replication commands synchronously where possible (#7876)
Most of the stuff we do for replication commands can be done synchronously. There's no point spinning up background processes if we're not going to need them.
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/handler.py115
-rw-r--r--synapse/replication/tcp/protocol.py45
-rw-r--r--synapse/replication/tcp/redis.py37
3 files changed, 111 insertions, 86 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1de590bba2..1c303f3a46 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -16,6 +16,7 @@
 import logging
 from typing import (
     Any,
+    Awaitable,
     Dict,
     Iterable,
     Iterator,
@@ -33,6 +34,7 @@ from typing_extensions import Deque
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
 from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
@@ -152,7 +154,7 @@ class ReplicationCommandHandler:
         # When POSITION or RDATA commands arrive, we stick them in a queue and process
         # them in order in a separate background process.
 
-        # the streams which are currently being processed by _unsafe_process_stream
+        # the streams which are currently being processed by _unsafe_process_queue
         self._processing_streams = set()  # type: Set[str]
 
         # for each stream, a queue of commands that are awaiting processing, and the
@@ -185,7 +187,7 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
-    async def _add_command_to_stream_queue(
+    def _add_command_to_stream_queue(
         self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing
@@ -199,33 +201,34 @@ class ReplicationCommandHandler:
             logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
             return
 
-        # if we're already processing this stream, stick the new command in the
-        # queue, and we're done.
+        queue.append((cmd, conn))
+
+        # if we're already processing this stream, there's nothing more to do:
+        # the new entry on the queue will get picked up in due course
         if stream_name in self._processing_streams:
-            queue.append((cmd, conn))
             return
 
-        # otherwise, process the new command.
+        # fire off a background process to start processing the queue.
+        run_as_background_process(
+            "process-replication-data", self._unsafe_process_queue, stream_name
+        )
 
-        # arguably we should start off a new background process here, but nothing
-        # will be too upset if we don't return for ages, so let's save the overhead
-        # and use the existing logcontext.
+    async def _unsafe_process_queue(self, stream_name: str):
+        """Processes the command queue for the given stream, until it is empty
+
+        Does not check if there is already a thread processing the queue, hence "unsafe"
+        """
+        assert stream_name not in self._processing_streams
 
         self._processing_streams.add(stream_name)
         try:
-            # might as well skip the queue for this one, since it must be empty
-            assert not queue
-            await self._process_command(cmd, conn, stream_name)
-
-            # now process any other commands that have built up while we were
-            # dealing with that one.
+            queue = self._command_queues_by_stream.get(stream_name)
             while queue:
                 cmd, conn = queue.popleft()
                 try:
                     await self._process_command(cmd, conn, stream_name)
                 except Exception:
                     logger.exception("Failed to handle command %s", cmd)
-
         finally:
             self._processing_streams.discard(stream_name)
 
@@ -299,7 +302,7 @@ class ReplicationCommandHandler:
         """
         return self._streams_to_replicate
 
-    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)
 
     def send_positions_to_connection(self, conn: AbstractConnection):
@@ -318,57 +321,73 @@ class ReplicationCommandHandler:
                 )
             )
 
-    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+    def on_USER_SYNC(
+        self, conn: AbstractConnection, cmd: UserSyncCommand
+    ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
         if self._is_master:
-            await self._presence_handler.update_external_syncs_row(
+            return self._presence_handler.update_external_syncs_row(
                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
             )
+        else:
+            return None
 
-    async def on_CLEAR_USER_SYNC(
+    def on_CLEAR_USER_SYNC(
         self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         if self._is_master:
-            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+            return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+        else:
+            return None
 
-    async def on_FEDERATION_ACK(
-        self, conn: AbstractConnection, cmd: FederationAckCommand
-    ):
+    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
-    async def on_REMOVE_PUSHER(
+    def on_REMOVE_PUSHER(
         self, conn: AbstractConnection, cmd: RemovePusherCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         remove_pusher_counter.inc()
 
         if self._is_master:
-            await self._store.delete_pusher_by_app_id_pushkey_user_id(
-                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
-            )
+            return self._handle_remove_pusher(cmd)
+        else:
+            return None
+
+    async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
+        await self._store.delete_pusher_by_app_id_pushkey_user_id(
+            app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+        )
 
-            self._notifier.on_new_replication_data()
+        self._notifier.on_new_replication_data()
 
-    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+    def on_USER_IP(
+        self, conn: AbstractConnection, cmd: UserIpCommand
+    ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
         if self._is_master:
-            await self._store.insert_client_ip(
-                cmd.user_id,
-                cmd.access_token,
-                cmd.ip,
-                cmd.user_agent,
-                cmd.device_id,
-                cmd.last_seen,
-            )
+            return self._handle_user_ip(cmd)
+        else:
+            return None
+
+    async def _handle_user_ip(self, cmd: UserIpCommand):
+        await self._store.insert_client_ip(
+            cmd.user_id,
+            cmd.access_token,
+            cmd.ip,
+            cmd.user_agent,
+            cmd.device_id,
+            cmd.last_seen,
+        )
 
-        if self._server_notices_sender:
-            await self._server_notices_sender.on_user_ip(cmd.user_id)
+        assert self._server_notices_sender is not None
+        await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -382,7 +401,7 @@ class ReplicationCommandHandler:
         #   2. so we don't race with getting a POSITION command and fetching
         #      missing RDATA.
 
-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_rdata(
         self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
@@ -459,14 +478,14 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )
 
-    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
 
         logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
 
-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_position(
         self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
@@ -526,9 +545,7 @@ class ReplicationCommandHandler:
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
-    async def on_REMOTE_SERVER_UP(
-        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
-    ):
+    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 23191e3218..0350923898 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc
 import fcntl
 import logging
 import struct
+from inspect import isawaitable
 from typing import TYPE_CHECKING, List
 
 from prometheus_client import Counter
@@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
     command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
+    `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
+    if so, that will get run as a background process.
 
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
@@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        self._logging_context = BackgroundProcessLoggingContext(
-            "replication_command_handler-%s" % self.conn_id
-        )
+        ctx_name = "replication-conn-%s" % self.conn_id
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+        self._logging_context.request = ctx_name
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
 
-        # Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)
 
-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.
 
         First calls `self.on_<COMMAND>` if it exists, then calls
-        `self.command_handler.on_<COMMAND>` if it exists. This allows for
-        protocol level handling of commands (e.g. PINGs), before delegating to
-        the handler.
+        `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+        return an Awaitable).
+
+        This allows for protocol level handling of commands (e.g. PINGs), before
+        delegating to the handler.
 
         Args:
             cmd: received command
@@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # specific handling.
         cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(cmd)
+            cmd_func(cmd)
             handled = True
 
         # Then call out to the handler.
         cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(self, cmd)
+            res = cmd_func(self, cmd)
+
+            # the handler might be a coroutine: fire it off as a background process
+            # if so.
+
+            if isawaitable(res):
+                run_as_background_process(
+                    "replication-" + cmd.get_logcontext_id(), lambda: res
+                )
+
             handled = True
 
         if not handled:
@@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         for cmd in pending:
             self.send_command(cmd)
 
-    async def on_PING(self, line):
+    def on_PING(self, line):
         self.received_ping = True
 
-    async def on_ERROR(self, cmd):
+    def on_ERROR(self, cmd):
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
     def pauseProducing(self):
@@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.send_command(ServerCommand(self.server_name))
         super().connectionMade()
 
-    async def on_NAME(self, cmd):
+    def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
@@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # Once we've connected subscribe to the necessary streams
         self.replicate()
 
-    async def on_SERVER(self, cmd):
+    def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index b5c533a607..f225e533de 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from inspect import isawaitable
 from typing import TYPE_CHECKING
 
 import txredisapi
@@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # remote instances.
         tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
-        # Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)
 
-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>, which should return an awaitable.
+        Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+        Awaitable).
 
         Args:
             cmd: received command
         """
-        handled = False
-
-        # First call any command handlers on this instance. These are for redis
-        # specific handling.
-        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(cmd)
-            handled = True
 
-        # Then call out to the handler.
         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(self, cmd)
-            handled = True
-
-        if not handled:
+        if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
+            return
+
+        res = cmd_func(self, cmd)
+
+        # the handler might be a coroutine: fire it off as a background process
+        # if so.
+
+        if isawaitable(res):
+            run_as_background_process(
+                "replication-" + cmd.get_logcontext_id(), lambda: res
+            )
 
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")