summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py45
1 files changed, 28 insertions, 17 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 23191e3218..0350923898 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc
 import fcntl
 import logging
 import struct
+from inspect import isawaitable
 from typing import TYPE_CHECKING, List
 
 from prometheus_client import Counter
@@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
     command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
+    `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
+    if so, that will get run as a background process.
 
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
@@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        self._logging_context = BackgroundProcessLoggingContext(
-            "replication_command_handler-%s" % self.conn_id
-        )
+        ctx_name = "replication-conn-%s" % self.conn_id
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+        self._logging_context.request = ctx_name
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
 
-        # Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)
 
-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.
 
         First calls `self.on_<COMMAND>` if it exists, then calls
-        `self.command_handler.on_<COMMAND>` if it exists. This allows for
-        protocol level handling of commands (e.g. PINGs), before delegating to
-        the handler.
+        `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+        return an Awaitable).
+
+        This allows for protocol level handling of commands (e.g. PINGs), before
+        delegating to the handler.
 
         Args:
             cmd: received command
@@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # specific handling.
         cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(cmd)
+            cmd_func(cmd)
             handled = True
 
         # Then call out to the handler.
         cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(self, cmd)
+            res = cmd_func(self, cmd)
+
+            # the handler might be a coroutine: fire it off as a background process
+            # if so.
+
+            if isawaitable(res):
+                run_as_background_process(
+                    "replication-" + cmd.get_logcontext_id(), lambda: res
+                )
+
             handled = True
 
         if not handled:
@@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         for cmd in pending:
             self.send_command(cmd)
 
-    async def on_PING(self, line):
+    def on_PING(self, line):
         self.received_ping = True
 
-    async def on_ERROR(self, cmd):
+    def on_ERROR(self, cmd):
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
     def pauseProducing(self):
@@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.send_command(ServerCommand(self.server_name))
         super().connectionMade()
 
-    async def on_NAME(self, cmd):
+    def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
@@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # Once we've connected subscribe to the necessary streams
         self.replicate()
 
-    async def on_SERVER(self, cmd):
+    def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")