summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/commands.py12
-rw-r--r--synapse/replication/tcp/handler.py18
2 files changed, 30 insertions, 0 deletions
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 10f5c98ff8..58a871c6d9 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -452,6 +452,17 @@ class LockReleasedCommand(Command):
         return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
 
 
+class NewActiveTaskCommand(_SimpleCommand):
+    """Sent to inform instance handling background tasks that a new active task is available to run.
+
+    Format::
+
+        NEW_ACTIVE_TASK "<task_id>"
+    """
+
+    NAME = "NEW_ACTIVE_TASK"
+
+
 _COMMANDS: Tuple[Type[Command], ...] = (
     ServerCommand,
     RdataCommand,
@@ -466,6 +477,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
     LockReleasedCommand,
+    NewActiveTaskCommand,
 )
 
 # Map of command name to command type.
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 38adcbe1d0..92c5a55acc 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import (
     Command,
     FederationAckCommand,
     LockReleasedCommand,
+    NewActiveTaskCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -238,6 +239,10 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
+        self._task_scheduler = None
+        if hs.config.worker.run_background_tasks:
+            self._task_scheduler = hs.get_task_scheduler()
+
         if hs.config.redis.redis_enabled:
             # If we're using Redis, it's the background worker that should
             # receive USER_IP commands and store the relevant client IPs.
@@ -663,6 +668,15 @@ class ReplicationCommandHandler:
             cmd.instance_name, cmd.lock_name, cmd.lock_key
         )
 
+    async def on_NEW_ACTIVE_TASK(
+        self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
+    ) -> None:
+        """Called when get a new NEW_ACTIVE_TASK command."""
+        if self._task_scheduler:
+            task = await self._task_scheduler.get_task(cmd.data)
+            if task:
+                await self._task_scheduler._launch_task(task)
+
     def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""
         self._connections.append(connection)
@@ -776,6 +790,10 @@ class ReplicationCommandHandler:
         if instance_name == self._instance_name:
             self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
 
+    def send_new_active_task(self, task_id: str) -> None:
+        """Called when a new task has been scheduled for immediate launch and is ACTIVE."""
+        self.send_command(NewActiveTaskCommand(task_id))
+
 
 UpdateToken = TypeVar("UpdateToken")
 UpdateRow = TypeVar("UpdateRow")