diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 58a871c6d9..e616b5e1c8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -267,27 +267,38 @@ class UserSyncCommand(Command):
NAME = "USER_SYNC"
def __init__(
- self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ self,
+ instance_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ last_sync_ms: int,
):
self.instance_id = instance_id
self.user_id = user_id
+ self.device_id = device_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
- instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
+ device_id: Optional[str]
+ instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4)
+
+ if device_id == "None":
+ device_id = None
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
- return cls(instance_id, user_id, state == "start", int(last_sync_ms))
+ return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms))
def to_line(self) -> str:
return " ".join(
(
self.instance_id,
self.user_id,
+ str(self.device_id),
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 92c5a55acc..d9045d7b73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -428,7 +428,11 @@ class ReplicationCommandHandler:
if self._is_presence_writer:
return self._presence_handler.update_external_syncs_row(
- cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+ cmd.instance_id,
+ cmd.user_id,
+ cmd.device_id,
+ cmd.is_syncing,
+ cmd.last_sync_ms,
)
else:
return None
@@ -699,9 +703,9 @@ class ReplicationCommandHandler:
)
now = self._clock.time_msec()
- for user_id in currently_syncing:
+ for user_id, device_id in currently_syncing:
connection.send_command(
- UserSyncCommand(self._instance_id, user_id, True, now)
+ UserSyncCommand(self._instance_id, user_id, device_id, True, now)
)
def lost_connection(self, connection: IReplicationConnection) -> None:
@@ -753,11 +757,16 @@ class ReplicationCommandHandler:
self.send_command(FederationAckCommand(self._instance_name, token))
def send_user_sync(
- self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ self,
+ instance_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ last_sync_ms: int,
) -> None:
"""Poke the master that a user has started/stopped syncing."""
self.send_command(
- UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+ UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms)
)
def send_user_ip(
|