summary refs log tree commit diff
path: root/synapse/replication/tcp/redis.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/redis.py')
-rw-r--r--synapse/replication/tcp/redis.py37
1 files changed, 17 insertions, 20 deletions
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index b5c533a607..f225e533de 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from inspect import isawaitable
 from typing import TYPE_CHECKING
 
 import txredisapi
@@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # remote instances.
         tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
-        # Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)
 
-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>, which should return an awaitable.
+        Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+        Awaitable).
 
         Args:
             cmd: received command
         """
-        handled = False
-
-        # First call any command handlers on this instance. These are for redis
-        # specific handling.
-        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(cmd)
-            handled = True
 
-        # Then call out to the handler.
         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(self, cmd)
-            handled = True
-
-        if not handled:
+        if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
+            return
+
+        res = cmd_func(self, cmd)
+
+        # the handler might be a coroutine: fire it off as a background process
+        # if so.
+
+        if isawaitable(res):
+            run_as_background_process(
+                "replication-" + cmd.get_logcontext_id(), lambda: res
+            )
 
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")