diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 38adcbe1d0..d9045d7b73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import (
Command,
FederationAckCommand,
LockReleasedCommand,
+ NewActiveTaskCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -238,6 +239,10 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
+ self._task_scheduler = None
+ if hs.config.worker.run_background_tasks:
+ self._task_scheduler = hs.get_task_scheduler()
+
if hs.config.redis.redis_enabled:
# If we're using Redis, it's the background worker that should
# receive USER_IP commands and store the relevant client IPs.
@@ -423,7 +428,11 @@ class ReplicationCommandHandler:
if self._is_presence_writer:
return self._presence_handler.update_external_syncs_row(
- cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+ cmd.instance_id,
+ cmd.user_id,
+ cmd.device_id,
+ cmd.is_syncing,
+ cmd.last_sync_ms,
)
else:
return None
@@ -663,6 +672,15 @@ class ReplicationCommandHandler:
cmd.instance_name, cmd.lock_name, cmd.lock_key
)
+ async def on_NEW_ACTIVE_TASK(
+ self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
+ ) -> None:
+ """Called when get a new NEW_ACTIVE_TASK command."""
+ if self._task_scheduler:
+ task = await self._task_scheduler.get_task(cmd.data)
+ if task:
+ await self._task_scheduler._launch_task(task)
+
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -685,9 +703,9 @@ class ReplicationCommandHandler:
)
now = self._clock.time_msec()
- for user_id in currently_syncing:
+ for user_id, device_id in currently_syncing:
connection.send_command(
- UserSyncCommand(self._instance_id, user_id, True, now)
+ UserSyncCommand(self._instance_id, user_id, device_id, True, now)
)
def lost_connection(self, connection: IReplicationConnection) -> None:
@@ -739,11 +757,16 @@ class ReplicationCommandHandler:
self.send_command(FederationAckCommand(self._instance_name, token))
def send_user_sync(
- self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ self,
+ instance_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ last_sync_ms: int,
) -> None:
"""Poke the master that a user has started/stopped syncing."""
self.send_command(
- UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+ UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms)
)
def send_user_ip(
@@ -776,6 +799,10 @@ class ReplicationCommandHandler:
if instance_name == self._instance_name:
self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
+ def send_new_active_task(self, task_id: str) -> None:
+ """Called when a new task has been scheduled for immediate launch and is ACTIVE."""
+ self.send_command(NewActiveTaskCommand(task_id))
+
UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")
|