diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index e776b63183..f225e533de 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,12 +14,16 @@
# limitations under the License.
import logging
+from inspect import isawaitable
from typing import TYPE_CHECKING
import txredisapi
-from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import (
+ BackgroundProcessLoggingContext,
+ run_as_background_process,
+)
from synapse.replication.tcp.commands import (
Command,
ReplicateCommand,
@@ -66,6 +70,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
stream_name = None # type: str
outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # a logcontext which we use for processing incoming commands. We declare it as a
+ # background process so that the CPU stats get reported to prometheus.
+ self._logging_context = BackgroundProcessLoggingContext(
+ "replication_command_handler"
+ )
+
def connectionMade(self):
logger.info("Connected to redis")
super().connectionMade()
@@ -92,7 +105,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
+ with PreserveLoggingContext(self._logging_context):
+ self._parse_and_dispatch_message(message)
+ def _parse_and_dispatch_message(self, message: str):
if message.strip() == "":
# Ignore blank lines
return
@@ -109,42 +125,41 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
- # Now lets try and call on_<CMD_NAME> function
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
- )
+ self.handle_command(cmd)
- async def handle_command(self, cmd: Command):
+ def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>, which should return an awaitable.
+ Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+ Awaitable).
Args:
cmd: received command
"""
- handled = False
-
- # First call any command handlers on this instance. These are for redis
- # specific handling.
- cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(cmd)
- handled = True
- # Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(self, cmd)
- handled = True
-
- if not handled:
+ if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
+ return
+
+ res = cmd_func(self, cmd)
+
+ # the handler might be a coroutine: fire it off as a background process
+ # if so.
+
+ if isawaitable(res):
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), lambda: res
+ )
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
self.handler.lost_connection(self)
+ # mark the logging context as finished
+ self._logging_context.__exit__(None, None, None)
+
def send_command(self, cmd: Command):
"""Send a command if connection has been established.
@@ -177,7 +192,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
- send outbound commands (this is seperate to the redis connection
+ send outbound commands (this is separate to the redis connection
used to subscribe).
"""
|