diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 73f3de3642..209833d287 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -62,7 +62,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
NAME = "multi_user_device_resync"
PATH_ARGS = ()
- CACHE = False
+ CACHE = True
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index db16aac9c2..6c9e79fb07 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request
@@ -51,14 +51,14 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
- async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
- return {}
+ async def _serialize_payload(user_id: str, device_id: Optional[str]) -> JsonDict: # type: ignore[override]
+ return {"device_id": device_id}
async def _handle_request( # type: ignore[override]
self, request: Request, content: JsonDict, user_id: str
) -> Tuple[int, JsonDict]:
await self._presence_handler.bump_presence_active_time(
- UserID.from_string(user_id)
+ UserID.from_string(user_id), content.get("device_id")
)
return (200, {})
@@ -73,8 +73,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
{
"state": { ... },
- "ignore_status_msg": false,
- "force_notify": false
+ "force_notify": false,
+ "is_sync": false
}
200 OK
@@ -95,14 +95,16 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
@staticmethod
async def _serialize_payload( # type: ignore[override]
user_id: str,
+ device_id: Optional[str],
state: JsonDict,
- ignore_status_msg: bool = False,
force_notify: bool = False,
+ is_sync: bool = False,
) -> JsonDict:
return {
+ "device_id": device_id,
"state": state,
- "ignore_status_msg": ignore_status_msg,
"force_notify": force_notify,
+ "is_sync": is_sync,
}
async def _handle_request( # type: ignore[override]
@@ -110,9 +112,10 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
) -> Tuple[int, JsonDict]:
await self._presence_handler.set_state(
UserID.from_string(user_id),
+ content.get("device_id"),
content["state"],
- content["ignore_status_msg"],
content["force_notify"],
+ content.get("is_sync", False),
)
return (200, {})
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3b88dc68ea..51285e6d33 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -422,7 +422,7 @@ class FederationSenderHandler:
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
if stream_name == "federation":
- send_queue.process_rows_for_federation(self.federation_sender, rows)
+ await send_queue.process_rows_for_federation(self.federation_sender, rows)
await self.update_token(token)
# ... and when new receipts happen
@@ -439,16 +439,14 @@ class FederationSenderHandler:
for row in rows
if not row.entity.startswith("@") and not row.is_signature
}
- for host in hosts:
- self.federation_sender.send_device_messages(host, immediate=False)
+ await self.federation_sender.send_device_messages(hosts, immediate=False)
elif stream_name == ToDeviceStream.NAME:
# The to_device stream includes stuff to be pushed to both local
# clients and remote servers, so we ignore entities that start with
# '@' (since they'll be local users rather than destinations).
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
- for host in hosts:
- self.federation_sender.send_device_messages(host)
+ await self.federation_sender.send_device_messages(hosts)
async def _on_new_receipts(
self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 10f5c98ff8..e616b5e1c8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -267,27 +267,38 @@ class UserSyncCommand(Command):
NAME = "USER_SYNC"
def __init__(
- self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ self,
+ instance_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ last_sync_ms: int,
):
self.instance_id = instance_id
self.user_id = user_id
+ self.device_id = device_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
- instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
+ device_id: Optional[str]
+ instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4)
+
+ if device_id == "None":
+ device_id = None
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
- return cls(instance_id, user_id, state == "start", int(last_sync_ms))
+ return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms))
def to_line(self) -> str:
return " ".join(
(
self.instance_id,
self.user_id,
+ str(self.device_id),
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
)
@@ -452,6 +463,17 @@ class LockReleasedCommand(Command):
return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
+class NewActiveTaskCommand(_SimpleCommand):
+ """Sent to inform instance handling background tasks that a new active task is available to run.
+
+ Format::
+
+ NEW_ACTIVE_TASK "<task_id>"
+ """
+
+ NAME = "NEW_ACTIVE_TASK"
+
+
_COMMANDS: Tuple[Type[Command], ...] = (
ServerCommand,
RdataCommand,
@@ -466,6 +488,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
RemoteServerUpCommand,
ClearUserSyncsCommand,
LockReleasedCommand,
+ NewActiveTaskCommand,
)
# Map of command name to command type.
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 38adcbe1d0..d9045d7b73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import (
Command,
FederationAckCommand,
LockReleasedCommand,
+ NewActiveTaskCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -238,6 +239,10 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
+ self._task_scheduler = None
+ if hs.config.worker.run_background_tasks:
+ self._task_scheduler = hs.get_task_scheduler()
+
if hs.config.redis.redis_enabled:
# If we're using Redis, it's the background worker that should
# receive USER_IP commands and store the relevant client IPs.
@@ -423,7 +428,11 @@ class ReplicationCommandHandler:
if self._is_presence_writer:
return self._presence_handler.update_external_syncs_row(
- cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+ cmd.instance_id,
+ cmd.user_id,
+ cmd.device_id,
+ cmd.is_syncing,
+ cmd.last_sync_ms,
)
else:
return None
@@ -663,6 +672,15 @@ class ReplicationCommandHandler:
cmd.instance_name, cmd.lock_name, cmd.lock_key
)
+ async def on_NEW_ACTIVE_TASK(
+ self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
+ ) -> None:
+ """Called when get a new NEW_ACTIVE_TASK command."""
+ if self._task_scheduler:
+ task = await self._task_scheduler.get_task(cmd.data)
+ if task:
+ await self._task_scheduler._launch_task(task)
+
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -685,9 +703,9 @@ class ReplicationCommandHandler:
)
now = self._clock.time_msec()
- for user_id in currently_syncing:
+ for user_id, device_id in currently_syncing:
connection.send_command(
- UserSyncCommand(self._instance_id, user_id, True, now)
+ UserSyncCommand(self._instance_id, user_id, device_id, True, now)
)
def lost_connection(self, connection: IReplicationConnection) -> None:
@@ -739,11 +757,16 @@ class ReplicationCommandHandler:
self.send_command(FederationAckCommand(self._instance_name, token))
def send_user_sync(
- self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ self,
+ instance_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ last_sync_ms: int,
) -> None:
"""Poke the master that a user has started/stopped syncing."""
self.send_command(
- UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+ UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms)
)
def send_user_ip(
@@ -776,6 +799,10 @@ class ReplicationCommandHandler:
if instance_name == self._instance_name:
self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
+ def send_new_active_task(self, task_id: str) -> None:
+ """Called when a new task has been scheduled for immediate launch and is ACTIVE."""
+ self.send_command(NewActiveTaskCommand(task_id))
+
UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")
|