diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 4198eece71..0b0d204e64 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc
import fcntl
import logging
import struct
+from inspect import isawaitable
from typing import TYPE_CHECKING, List
from prometheus_client import Counter
@@ -57,8 +58,12 @@ from prometheus_client import Counter
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
+from synapse.logging.context import PreserveLoggingContext
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ BackgroundProcessLoggingContext,
+ run_as_background_process,
+)
from synapse.replication.tcp.commands import (
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
@@ -108,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5
PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER
-class ConnectionStates(object):
+class ConnectionStates:
CONNECTING = "connecting"
ESTABLISHED = "established"
PAUSED = "paused"
@@ -124,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
+ `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return an awaitable
+ (e.g. a coroutine); if so, it will be run as a background process.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@@ -160,6 +167,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
+ # a logcontext which we use for processing incoming commands. We declare it as a
+ # background process so that the CPU stats get reported to prometheus.
+ ctx_name = "replication-conn-%s" % self.conn_id
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+ self._logging_context.request = ctx_name
+
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -210,6 +223,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def lineReceived(self, line: bytes):
"""Called when we've received a line
"""
+ with PreserveLoggingContext(self._logging_context):
+ self._parse_and_dispatch_line(line)
+
+ def _parse_and_dispatch_line(self, line: bytes):
if line.strip() == "":
# Ignore blank lines
return
@@ -232,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
- # Now lets try and call on_<CMD_NAME> function
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
- )
+ self.handle_command(cmd)
- async def handle_command(self, cmd: Command):
+ def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
First calls `self.on_<COMMAND>` if it exists, then calls
- `self.command_handler.on_<COMMAND>` if it exists. This allows for
- protocol level handling of commands (e.g. PINGs), before delegating to
- the handler.
+ `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+ return an Awaitable).
+
+ This allows for protocol level handling of commands (e.g. PINGs), before
+ delegating to the handler.
Args:
cmd: received command
@@ -254,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(self, cmd)
+ res = cmd_func(self, cmd)
+
+ # the handler may return an awaitable (e.g. a coroutine): if so, fire
+ # it off as a background process.
+
+ if isawaitable(res):
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), lambda: res
+ )
+
handled = True
if not handled:
@@ -317,7 +342,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def _queue_command(self, cmd):
"""Queue the command until the connection is ready to write to again.
"""
- logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+ logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
if len(self.pending_commands) > self.max_line_buffer:
@@ -336,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- async def on_PING(self, line):
+ def on_PING(self, line):
self.received_ping = True
- async def on_ERROR(self, cmd):
+ def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -397,6 +422,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if self.transport:
self.transport.unregisterProducer()
+ # mark the logging context as finished
+ self._logging_context.__exit__(None, None, None)
+
def __str__(self):
addr = None
if self.transport:
@@ -431,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ServerCommand(self.server_name))
super().connectionMade()
- async def on_NAME(self, cmd):
+ def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
@@ -460,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Once we've connected subscribe to the necessary streams
self.replicate()
- async def on_SERVER(self, cmd):
+ def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
|