summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/devices.py2
-rw-r--r--synapse/replication/http/presence.py21
-rw-r--r--synapse/replication/tcp/client.py8
-rw-r--r--synapse/replication/tcp/commands.py29
-rw-r--r--synapse/replication/tcp/handler.py37
5 files changed, 74 insertions, 23 deletions
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 73f3de3642..209833d287 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -62,7 +62,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
 
     NAME = "multi_user_device_resync"
     PATH_ARGS = ()
-    CACHE = False
+    CACHE = True
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index db16aac9c2..6c9e79fb07 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from twisted.web.server import Request
 
@@ -51,14 +51,14 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
         self._presence_handler = hs.get_presence_handler()
 
     @staticmethod
-    async def _serialize_payload(user_id: str) -> JsonDict:  # type: ignore[override]
-        return {}
+    async def _serialize_payload(user_id: str, device_id: Optional[str]) -> JsonDict:  # type: ignore[override]
+        return {"device_id": device_id}
 
     async def _handle_request(  # type: ignore[override]
         self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, JsonDict]:
         await self._presence_handler.bump_presence_active_time(
-            UserID.from_string(user_id)
+            UserID.from_string(user_id), content.get("device_id")
         )
 
         return (200, {})
@@ -73,8 +73,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
 
         {
             "state": { ... },
-            "ignore_status_msg": false,
-            "force_notify": false
+            "force_notify": false,
+            "is_sync": false
         }
 
         200 OK
@@ -95,14 +95,16 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
     @staticmethod
     async def _serialize_payload(  # type: ignore[override]
         user_id: str,
+        device_id: Optional[str],
         state: JsonDict,
-        ignore_status_msg: bool = False,
         force_notify: bool = False,
+        is_sync: bool = False,
     ) -> JsonDict:
         return {
+            "device_id": device_id,
             "state": state,
-            "ignore_status_msg": ignore_status_msg,
             "force_notify": force_notify,
+            "is_sync": is_sync,
         }
 
     async def _handle_request(  # type: ignore[override]
@@ -110,9 +112,10 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
     ) -> Tuple[int, JsonDict]:
         await self._presence_handler.set_state(
             UserID.from_string(user_id),
+            content.get("device_id"),
             content["state"],
-            content["ignore_status_msg"],
             content["force_notify"],
+            content.get("is_sync", False),
         )
 
         return (200, {})
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 078c8d7074..ca8a76f77c 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -410,7 +410,7 @@ class FederationSenderHandler:
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
         if stream_name == "federation":
-            send_queue.process_rows_for_federation(self.federation_sender, rows)
+            await send_queue.process_rows_for_federation(self.federation_sender, rows)
             await self.update_token(token)
 
         # ... and when new receipts happen
@@ -427,16 +427,14 @@ class FederationSenderHandler:
                 for row in rows
                 if not row.entity.startswith("@") and not row.is_signature
             }
-            for host in hosts:
-                self.federation_sender.send_device_messages(host, immediate=False)
+            await self.federation_sender.send_device_messages(hosts, immediate=False)
 
         elif stream_name == ToDeviceStream.NAME:
             # The to_device stream includes stuff to be pushed to both local
             # clients and remote servers, so we ignore entities that start with
             # '@' (since they'll be local users rather than destinations).
             hosts = {row.entity for row in rows if not row.entity.startswith("@")}
-            for host in hosts:
-                self.federation_sender.send_device_messages(host)
+            await self.federation_sender.send_device_messages(hosts)
 
     async def _on_new_receipts(
         self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 10f5c98ff8..e616b5e1c8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -267,27 +267,38 @@ class UserSyncCommand(Command):
     NAME = "USER_SYNC"
 
     def __init__(
-        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+        self,
+        instance_id: str,
+        user_id: str,
+        device_id: Optional[str],
+        is_syncing: bool,
+        last_sync_ms: int,
     ):
         self.instance_id = instance_id
         self.user_id = user_id
+        self.device_id = device_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
     def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
-        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
+        device_id: Optional[str]
+        instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4)
+
+        if device_id == "None":
+            device_id = None
 
         if state not in ("start", "end"):
             raise Exception("Invalid USER_SYNC state %r" % (state,))
 
-        return cls(instance_id, user_id, state == "start", int(last_sync_ms))
+        return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms))
 
     def to_line(self) -> str:
         return " ".join(
             (
                 self.instance_id,
                 self.user_id,
+                str(self.device_id),
                 "start" if self.is_syncing else "end",
                 str(self.last_sync_ms),
             )
@@ -452,6 +463,17 @@ class LockReleasedCommand(Command):
         return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
 
 
+class NewActiveTaskCommand(_SimpleCommand):
+    """Sent to inform instance handling background tasks that a new active task is available to run.
+
+    Format::
+
+        NEW_ACTIVE_TASK "<task_id>"
+    """
+
+    NAME = "NEW_ACTIVE_TASK"
+
+
 _COMMANDS: Tuple[Type[Command], ...] = (
     ServerCommand,
     RdataCommand,
@@ -466,6 +488,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
     LockReleasedCommand,
+    NewActiveTaskCommand,
 )
 
 # Map of command name to command type.
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 38adcbe1d0..d9045d7b73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import (
     Command,
     FederationAckCommand,
     LockReleasedCommand,
+    NewActiveTaskCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -238,6 +239,10 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
+        self._task_scheduler = None
+        if hs.config.worker.run_background_tasks:
+            self._task_scheduler = hs.get_task_scheduler()
+
         if hs.config.redis.redis_enabled:
             # If we're using Redis, it's the background worker that should
             # receive USER_IP commands and store the relevant client IPs.
@@ -423,7 +428,11 @@ class ReplicationCommandHandler:
 
         if self._is_presence_writer:
             return self._presence_handler.update_external_syncs_row(
-                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+                cmd.instance_id,
+                cmd.user_id,
+                cmd.device_id,
+                cmd.is_syncing,
+                cmd.last_sync_ms,
             )
         else:
             return None
@@ -663,6 +672,15 @@ class ReplicationCommandHandler:
             cmd.instance_name, cmd.lock_name, cmd.lock_key
         )
 
+    async def on_NEW_ACTIVE_TASK(
+        self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
+    ) -> None:
+        """Called when get a new NEW_ACTIVE_TASK command."""
+        if self._task_scheduler:
+            task = await self._task_scheduler.get_task(cmd.data)
+            if task:
+                await self._task_scheduler._launch_task(task)
+
     def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""
         self._connections.append(connection)
@@ -685,9 +703,9 @@ class ReplicationCommandHandler:
         )
 
         now = self._clock.time_msec()
-        for user_id in currently_syncing:
+        for user_id, device_id in currently_syncing:
             connection.send_command(
-                UserSyncCommand(self._instance_id, user_id, True, now)
+                UserSyncCommand(self._instance_id, user_id, device_id, True, now)
             )
 
     def lost_connection(self, connection: IReplicationConnection) -> None:
@@ -739,11 +757,16 @@ class ReplicationCommandHandler:
         self.send_command(FederationAckCommand(self._instance_name, token))
 
     def send_user_sync(
-        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+        self,
+        instance_id: str,
+        user_id: str,
+        device_id: Optional[str],
+        is_syncing: bool,
+        last_sync_ms: int,
     ) -> None:
         """Poke the master that a user has started/stopped syncing."""
         self.send_command(
-            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+            UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms)
         )
 
     def send_user_ip(
@@ -776,6 +799,10 @@ class ReplicationCommandHandler:
         if instance_name == self._instance_name:
             self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
 
+    def send_new_active_task(self, task_id: str) -> None:
+        """Called when a new task has been scheduled for immediate launch and is ACTIVE."""
+        self.send_command(NewActiveTaskCommand(task_id))
+
 
 UpdateToken = TypeVar("UpdateToken")
 UpdateRow = TypeVar("UpdateRow")