diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 50c68c86ce..2f841863ae 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -23,6 +23,7 @@ The methods that define policy are:
"""
import abc
import contextlib
+import itertools
import logging
from bisect import bisect
from contextlib import contextmanager
@@ -188,15 +189,17 @@ class BasePresenceHandler(abc.ABC):
"""
@abc.abstractmethod
- def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
- """Get an iterable of syncing users on this worker, to send to the presence handler
+ def get_currently_syncing_users_for_replication(
+ self,
+ ) -> Iterable[Tuple[str, Optional[str]]]:
+ """Get an iterable of syncing users and devices on this worker, to send to the presence handler
This is called when a replication connection is established. It should return
- a list of user ids, which are then sent as USER_SYNC commands to inform the
- process handling presence about those users.
+ a list of tuples of user ID & device ID, which are then sent as USER_SYNC commands
+ to inform the process handling presence about those users/devices.
Returns:
- An iterable of user_id strings.
+ An iterable of tuples of user ID and device ID.
"""
async def get_state(self, target_user: UserID) -> UserPresenceState:
@@ -284,7 +287,12 @@ class BasePresenceHandler(abc.ABC):
"""
async def update_external_syncs_row( # noqa: B027 (no-op by design)
- self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
+ self,
+ process_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ sync_time_msec: int,
) -> None:
"""Update the syncing users for an external process as a delta.
@@ -295,6 +303,7 @@ class BasePresenceHandler(abc.ABC):
syncing against. This allows synapse to process updates
as user start and stop syncing against a given process.
user_id: The user who has started or stopped syncing
+ device_id: The user's device that has started or stopped syncing
is_syncing: Whether or not the user is now syncing
sync_time_msec: Time in ms when the user was last syncing
"""
@@ -425,16 +434,18 @@ class WorkerPresenceHandler(BasePresenceHandler):
hs.config.worker.writers.presence,
)
- # The number of ongoing syncs on this process, by user id.
+ # The number of ongoing syncs on this process, by (user ID, device ID).
# Empty if _presence_enabled is false.
- self._user_to_num_current_syncs: Dict[str, int] = {}
+ self._user_device_to_num_current_syncs: Dict[
+ Tuple[str, Optional[str]], int
+ ] = {}
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
- # user_id -> last_sync_ms. Lists the users that have stopped syncing but
- # we haven't notified the presence writer of that yet
- self.users_going_offline: Dict[str, int] = {}
+ # (user_id, device_id) -> last_sync_ms. Lists the devices that have stopped
+ # syncing but we haven't notified the presence writer of that yet
+ self._user_devices_going_offline: Dict[Tuple[str, Optional[str]], int] = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@@ -457,39 +468,47 @@ class WorkerPresenceHandler(BasePresenceHandler):
ClearUserSyncsCommand(self.instance_id)
)
- def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None:
+ def send_user_sync(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ last_sync_ms: int,
+ ) -> None:
if self._presence_enabled:
self.hs.get_replication_command_handler().send_user_sync(
- self.instance_id, user_id, is_syncing, last_sync_ms
+ self.instance_id, user_id, device_id, is_syncing, last_sync_ms
)
- def mark_as_coming_online(self, user_id: str) -> None:
+ def mark_as_coming_online(self, user_id: str, device_id: Optional[str]) -> None:
"""A user has started syncing. Send a UserSync to the presence writer,
unless they had recently stopped syncing.
"""
- going_offline = self.users_going_offline.pop(user_id, None)
+ going_offline = self._user_devices_going_offline.pop((user_id, device_id), None)
if not going_offline:
# Safe to skip because we haven't yet told the presence writer they
# were offline
- self.send_user_sync(user_id, True, self.clock.time_msec())
+ self.send_user_sync(user_id, device_id, True, self.clock.time_msec())
- def mark_as_going_offline(self, user_id: str) -> None:
+ def mark_as_going_offline(self, user_id: str, device_id: Optional[str]) -> None:
"""A user has stopped syncing. We wait before notifying the presence
writer as its likely they'll come back soon. This allows us to avoid
sending a stopped syncing immediately followed by a started syncing
notification to the presence writer
"""
- self.users_going_offline[user_id] = self.clock.time_msec()
+ self._user_devices_going_offline[(user_id, device_id)] = self.clock.time_msec()
def send_stop_syncing(self) -> None:
"""Check if there are any users who have stopped syncing a while ago and
haven't come back yet. If there are poke the presence writer about them.
"""
now = self.clock.time_msec()
- for user_id, last_sync_ms in list(self.users_going_offline.items()):
+ for (user_id, device_id), last_sync_ms in list(
+ self._user_devices_going_offline.items()
+ ):
if now - last_sync_ms > UPDATE_SYNCING_USERS_MS:
- self.users_going_offline.pop(user_id, None)
- self.send_user_sync(user_id, False, last_sync_ms)
+ self._user_devices_going_offline.pop((user_id, device_id), None)
+ self.send_user_sync(user_id, device_id, False, last_sync_ms)
async def user_syncing(
self,
@@ -515,23 +534,23 @@ class WorkerPresenceHandler(BasePresenceHandler):
is_sync=True,
)
- curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
- self._user_to_num_current_syncs[user_id] = curr_sync + 1
+ curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0)
+ self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1
# If this is the first in-flight sync, notify replication
- if self._user_to_num_current_syncs[user_id] == 1:
- self.mark_as_coming_online(user_id)
+ if self._user_device_to_num_current_syncs[(user_id, device_id)] == 1:
+ self.mark_as_coming_online(user_id, device_id)
def _end() -> None:
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
- if user_id in self._user_to_num_current_syncs:
- self._user_to_num_current_syncs[user_id] -= 1
+ if (user_id, device_id) in self._user_device_to_num_current_syncs:
+ self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1
# If there are no more in-flight syncs, notify replication
- if self._user_to_num_current_syncs[user_id] == 0:
- self.mark_as_going_offline(user_id)
+ if self._user_device_to_num_current_syncs[(user_id, device_id)] == 0:
+ self.mark_as_going_offline(user_id, device_id)
@contextlib.contextmanager
def _user_syncing() -> Generator[None, None, None]:
@@ -598,10 +617,12 @@ class WorkerPresenceHandler(BasePresenceHandler):
# If this is a federation sender, notify about presence updates.
await self.maybe_send_presence_to_interested_destinations(state_to_notify)
- def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+ def get_currently_syncing_users_for_replication(
+ self,
+ ) -> Iterable[Tuple[str, Optional[str]]]:
return [
- user_id
- for user_id, count in self._user_to_num_current_syncs.items()
+ user_id_device_id
+ for user_id_device_id, count in self._user_device_to_num_current_syncs.items()
if count > 0
]
@@ -723,17 +744,23 @@ class PresenceHandler(BasePresenceHandler):
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
- self.user_to_num_current_syncs: Dict[str, int] = {}
+ self._user_device_to_num_current_syncs: Dict[
+ Tuple[str, Optional[str]], int
+ ] = {}
# Keeps track of the number of *ongoing* syncs on other processes.
+ #
# While any sync is ongoing on another process the user will never
# go offline.
+ #
# Each process has a unique identifier and an update frequency. If
# no update is received from that process within the update period then
# we assume that all the sync requests on that process have stopped.
- # Stored as a dict from process_id to set of user_id, and a dict of
- # process_id to millisecond timestamp last updated.
- self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
+ # Stored as a dict from process_id to set of (user_id, device_id), and
+ # a dict of process_id to millisecond timestamp last updated.
+ self.external_process_to_current_syncs: Dict[
+ str, Set[Tuple[str, Optional[str]]]
+ ] = {}
self.external_process_last_updated_ms: Dict[str, int] = {}
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
@@ -938,7 +965,10 @@ class PresenceHandler(BasePresenceHandler):
# that were syncing on that process to see if they need to be timed
# out.
users_to_check.update(
- self.external_process_to_current_syncs.pop(process_id, ())
+ user_id
+ for user_id, device_id in self.external_process_to_current_syncs.pop(
+ process_id, ()
+ )
)
self.external_process_last_updated_ms.pop(process_id)
@@ -951,11 +981,15 @@ class PresenceHandler(BasePresenceHandler):
syncing_user_ids = {
user_id
- for user_id, count in self.user_to_num_current_syncs.items()
+ for (user_id, _), count in self._user_device_to_num_current_syncs.items()
if count
}
- for user_ids in self.external_process_to_current_syncs.values():
- syncing_user_ids.update(user_ids)
+ syncing_user_ids.update(
+ user_id
+ for user_id, _ in itertools.chain(
+ *self.external_process_to_current_syncs.values()
+ )
+ )
changes = handle_timeouts(
states,
@@ -1013,8 +1047,8 @@ class PresenceHandler(BasePresenceHandler):
if not affect_presence or not self._presence_enabled:
return _NullContextManager()
- curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
- self.user_to_num_current_syncs[user_id] = curr_sync + 1
+ curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0)
+ self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1
# Note that this causes last_active_ts to be incremented which is not
# what the spec wants.
@@ -1027,7 +1061,7 @@ class PresenceHandler(BasePresenceHandler):
async def _end() -> None:
try:
- self.user_to_num_current_syncs[user_id] -= 1
+ self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1
prev_state = await self.current_state_for_user(user_id)
await self._update_states(
@@ -1049,12 +1083,19 @@ class PresenceHandler(BasePresenceHandler):
return _user_syncing()
- def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+ def get_currently_syncing_users_for_replication(
+ self,
+ ) -> Iterable[Tuple[str, Optional[str]]]:
# since we are the process handling presence, there is nothing to do here.
return []
async def update_external_syncs_row(
- self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
+ self,
+ process_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ sync_time_msec: int,
) -> None:
"""Update the syncing users for an external process as a delta.
@@ -1063,6 +1104,7 @@ class PresenceHandler(BasePresenceHandler):
syncing against. This allows synapse to process updates
as user start and stop syncing against a given process.
user_id: The user who has started or stopped syncing
+ device_id: The user's device that has started or stopped syncing
is_syncing: Whether or not the user is now syncing
sync_time_msec: Time in ms when the user was last syncing
"""
@@ -1073,26 +1115,27 @@ class PresenceHandler(BasePresenceHandler):
process_id, set()
)
- # USER_SYNC is sent when a user starts or stops syncing on a remote
- # process. (But only for the initial and last device.)
+ # USER_SYNC is sent when a user's device starts or stops syncing on
+ # a remote # process. (But only for the initial and last sync for that
+ # device.)
#
- # When a user *starts* syncing it also calls set_state(...) which
+ # When a device *starts* syncing it also calls set_state(...) which
# will update the state, last_active_ts, and last_user_sync_ts.
- # Simply ensure the user is tracked as syncing in this case.
+ # Simply ensure the user & device is tracked as syncing in this case.
#
- # When a user *stops* syncing, update the last_user_sync_ts and mark
+ # When a device *stops* syncing, update the last_user_sync_ts and mark
# them as no longer syncing. Note this doesn't quite match the
# monolith behaviour, which updates last_user_sync_ts at the end of
# every sync, not just the last in-flight sync.
- if is_syncing and user_id not in process_presence:
- process_presence.add(user_id)
- elif not is_syncing and user_id in process_presence:
+ if is_syncing and (user_id, device_id) not in process_presence:
+ process_presence.add((user_id, device_id))
+ elif not is_syncing and (user_id, device_id) in process_presence:
new_state = prev_state.copy_and_replace(
last_user_sync_ts=sync_time_msec
)
await self._update_states([new_state])
- process_presence.discard(user_id)
+ process_presence.discard((user_id, device_id))
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
@@ -1106,7 +1149,9 @@ class PresenceHandler(BasePresenceHandler):
process_presence = self.external_process_to_current_syncs.pop(
process_id, set()
)
- prev_states = await self.current_state_for_users(process_presence)
+ prev_states = await self.current_state_for_users(
+ {user_id for user_id, device_id in process_presence}
+ )
time_now_ms = self.clock.time_msec()
await self._update_states(
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 58a871c6d9..e616b5e1c8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -267,27 +267,38 @@ class UserSyncCommand(Command):
NAME = "USER_SYNC"
def __init__(
- self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ self,
+ instance_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ last_sync_ms: int,
):
self.instance_id = instance_id
self.user_id = user_id
+ self.device_id = device_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
- instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
+ device_id: Optional[str]
+ instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4)
+
+ if device_id == "None":
+ device_id = None
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
- return cls(instance_id, user_id, state == "start", int(last_sync_ms))
+ return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms))
def to_line(self) -> str:
return " ".join(
(
self.instance_id,
self.user_id,
+ str(self.device_id),
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 92c5a55acc..d9045d7b73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -428,7 +428,11 @@ class ReplicationCommandHandler:
if self._is_presence_writer:
return self._presence_handler.update_external_syncs_row(
- cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+ cmd.instance_id,
+ cmd.user_id,
+ cmd.device_id,
+ cmd.is_syncing,
+ cmd.last_sync_ms,
)
else:
return None
@@ -699,9 +703,9 @@ class ReplicationCommandHandler:
)
now = self._clock.time_msec()
- for user_id in currently_syncing:
+ for user_id, device_id in currently_syncing:
connection.send_command(
- UserSyncCommand(self._instance_id, user_id, True, now)
+ UserSyncCommand(self._instance_id, user_id, device_id, True, now)
)
def lost_connection(self, connection: IReplicationConnection) -> None:
@@ -753,11 +757,16 @@ class ReplicationCommandHandler:
self.send_command(FederationAckCommand(self._instance_name, token))
def send_user_sync(
- self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ self,
+ instance_id: str,
+ user_id: str,
+ device_id: Optional[str],
+ is_syncing: bool,
+ last_sync_ms: int,
) -> None:
"""Poke the master that a user has started/stopped syncing."""
self.send_command(
- UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+ UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms)
)
def send_user_ip(
|