summary refs log tree commit diff
path: root/synapse/replication/tcp/redis.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/redis.py')
-rw-r--r--synapse/replication/tcp/redis.py61
1 files changed, 38 insertions, 23 deletions
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index e776b63183..f225e533de 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,12 +14,16 @@
 # limitations under the License.
 
 import logging
+from inspect import isawaitable
 from typing import TYPE_CHECKING
 
 import txredisapi
 
-from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import (
+    BackgroundProcessLoggingContext,
+    run_as_background_process,
+)
 from synapse.replication.tcp.commands import (
     Command,
     ReplicateCommand,
@@ -66,6 +70,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     stream_name = None  # type: str
     outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # a logcontext which we use for processing incoming commands. We declare it as a
+        # background process so that the CPU stats get reported to prometheus.
+        self._logging_context = BackgroundProcessLoggingContext(
+            "replication_command_handler"
+        )
+
     def connectionMade(self):
         logger.info("Connected to redis")
         super().connectionMade()
@@ -92,7 +105,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
         """
+        with PreserveLoggingContext(self._logging_context):
+            self._parse_and_dispatch_message(message)
 
+    def _parse_and_dispatch_message(self, message: str):
         if message.strip() == "":
             # Ignore blank lines
             return
@@ -109,42 +125,41 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # remote instances.
         tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
-        # Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)
 
-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>, which should return an awaitable.
+        Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+        Awaitable).
 
         Args:
             cmd: received command
         """
-        handled = False
-
-        # First call any command handlers on this instance. These are for redis
-        # specific handling.
-        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(cmd)
-            handled = True
 
-        # Then call out to the handler.
         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(self, cmd)
-            handled = True
-
-        if not handled:
+        if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
+            return
+
+        res = cmd_func(self, cmd)
+
+        # the handler might be a coroutine: fire it off as a background process
+        # if so.
+
+        if isawaitable(res):
+            run_as_background_process(
+                "replication-" + cmd.get_logcontext_id(), lambda: res
+            )
 
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
         self.handler.lost_connection(self)
 
+        # mark the logging context as finished
+        self._logging_context.__exit__(None, None, None)
+
     def send_command(self, cmd: Command):
         """Send a command if connection has been established.
 
@@ -177,7 +192,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
     Args:
         hs
         outbound_redis_connection: A connection to redis that will be used to
-            send outbound commands (this is seperate to the redis connection
+            send outbound commands (this is separate to the redis connection
             used to subscribe).
     """