summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/presence.py155
-rw-r--r--synapse/replication/tcp/commands.py17
-rw-r--r--synapse/replication/tcp/handler.py19
3 files changed, 128 insertions, 63 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 50c68c86ce..2f841863ae 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -23,6 +23,7 @@ The methods that define policy are:
 """
 import abc
 import contextlib
+import itertools
 import logging
 from bisect import bisect
 from contextlib import contextmanager
@@ -188,15 +189,17 @@ class BasePresenceHandler(abc.ABC):
         """
 
     @abc.abstractmethod
-    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
-        """Get an iterable of syncing users on this worker, to send to the presence handler
+    def get_currently_syncing_users_for_replication(
+        self,
+    ) -> Iterable[Tuple[str, Optional[str]]]:
+        """Get an iterable of syncing users and devices on this worker, to send to the presence handler
 
         This is called when a replication connection is established. It should return
-        a list of user ids, which are then sent as USER_SYNC commands to inform the
-        process handling presence about those users.
+        a list of tuples of user ID & device ID, which are then sent as USER_SYNC commands
+        to inform the process handling presence about those users/devices.
 
         Returns:
-            An iterable of user_id strings.
+            An iterable of tuples of user ID and device ID.
         """
 
     async def get_state(self, target_user: UserID) -> UserPresenceState:
@@ -284,7 +287,12 @@ class BasePresenceHandler(abc.ABC):
         """
 
     async def update_external_syncs_row(  # noqa: B027 (no-op by design)
-        self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
+        self,
+        process_id: str,
+        user_id: str,
+        device_id: Optional[str],
+        is_syncing: bool,
+        sync_time_msec: int,
     ) -> None:
         """Update the syncing users for an external process as a delta.
 
@@ -295,6 +303,7 @@ class BasePresenceHandler(abc.ABC):
                 syncing against. This allows synapse to process updates
                 as user start and stop syncing against a given process.
             user_id: The user who has started or stopped syncing
+            device_id: The user's device that has started or stopped syncing
             is_syncing: Whether or not the user is now syncing
             sync_time_msec: Time in ms when the user was last syncing
         """
@@ -425,16 +434,18 @@ class WorkerPresenceHandler(BasePresenceHandler):
             hs.config.worker.writers.presence,
         )
 
-        # The number of ongoing syncs on this process, by user id.
+        # The number of ongoing syncs on this process, by (user ID, device ID).
         # Empty if _presence_enabled is false.
-        self._user_to_num_current_syncs: Dict[str, int] = {}
+        self._user_device_to_num_current_syncs: Dict[
+            Tuple[str, Optional[str]], int
+        ] = {}
 
         self.notifier = hs.get_notifier()
         self.instance_id = hs.get_instance_id()
 
-        # user_id -> last_sync_ms. Lists the users that have stopped syncing but
-        # we haven't notified the presence writer of that yet
-        self.users_going_offline: Dict[str, int] = {}
+        # (user_id, device_id) -> last_sync_ms. Lists the devices that have stopped
+        # syncing but we haven't notified the presence writer of that yet
+        self._user_devices_going_offline: Dict[Tuple[str, Optional[str]], int] = {}
 
         self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
         self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@@ -457,39 +468,47 @@ class WorkerPresenceHandler(BasePresenceHandler):
                 ClearUserSyncsCommand(self.instance_id)
             )
 
-    def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None:
+    def send_user_sync(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        is_syncing: bool,
+        last_sync_ms: int,
+    ) -> None:
         if self._presence_enabled:
             self.hs.get_replication_command_handler().send_user_sync(
-                self.instance_id, user_id, is_syncing, last_sync_ms
+                self.instance_id, user_id, device_id, is_syncing, last_sync_ms
             )
 
-    def mark_as_coming_online(self, user_id: str) -> None:
+    def mark_as_coming_online(self, user_id: str, device_id: Optional[str]) -> None:
         """A user has started syncing. Send a UserSync to the presence writer,
         unless they had recently stopped syncing.
         """
-        going_offline = self.users_going_offline.pop(user_id, None)
+        going_offline = self._user_devices_going_offline.pop((user_id, device_id), None)
         if not going_offline:
             # Safe to skip because we haven't yet told the presence writer they
             # were offline
-            self.send_user_sync(user_id, True, self.clock.time_msec())
+            self.send_user_sync(user_id, device_id, True, self.clock.time_msec())
 
-    def mark_as_going_offline(self, user_id: str) -> None:
+    def mark_as_going_offline(self, user_id: str, device_id: Optional[str]) -> None:
         """A user has stopped syncing. We wait before notifying the presence
         writer as its likely they'll come back soon. This allows us to avoid
         sending a stopped syncing immediately followed by a started syncing
         notification to the presence writer
         """
-        self.users_going_offline[user_id] = self.clock.time_msec()
+        self._user_devices_going_offline[(user_id, device_id)] = self.clock.time_msec()
 
     def send_stop_syncing(self) -> None:
         """Check if there are any users who have stopped syncing a while ago and
         haven't come back yet. If there are poke the presence writer about them.
         """
         now = self.clock.time_msec()
-        for user_id, last_sync_ms in list(self.users_going_offline.items()):
+        for (user_id, device_id), last_sync_ms in list(
+            self._user_devices_going_offline.items()
+        ):
             if now - last_sync_ms > UPDATE_SYNCING_USERS_MS:
-                self.users_going_offline.pop(user_id, None)
-                self.send_user_sync(user_id, False, last_sync_ms)
+                self._user_devices_going_offline.pop((user_id, device_id), None)
+                self.send_user_sync(user_id, device_id, False, last_sync_ms)
 
     async def user_syncing(
         self,
@@ -515,23 +534,23 @@ class WorkerPresenceHandler(BasePresenceHandler):
             is_sync=True,
         )
 
-        curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
-        self._user_to_num_current_syncs[user_id] = curr_sync + 1
+        curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0)
+        self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1
 
         # If this is the first in-flight sync, notify replication
-        if self._user_to_num_current_syncs[user_id] == 1:
-            self.mark_as_coming_online(user_id)
+        if self._user_device_to_num_current_syncs[(user_id, device_id)] == 1:
+            self.mark_as_coming_online(user_id, device_id)
 
         def _end() -> None:
             # We check that the user_id is in user_to_num_current_syncs because
             # user_to_num_current_syncs may have been cleared if we are
             # shutting down.
-            if user_id in self._user_to_num_current_syncs:
-                self._user_to_num_current_syncs[user_id] -= 1
+            if (user_id, device_id) in self._user_device_to_num_current_syncs:
+                self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1
 
                 # If there are no more in-flight syncs, notify replication
-                if self._user_to_num_current_syncs[user_id] == 0:
-                    self.mark_as_going_offline(user_id)
+                if self._user_device_to_num_current_syncs[(user_id, device_id)] == 0:
+                    self.mark_as_going_offline(user_id, device_id)
 
         @contextlib.contextmanager
         def _user_syncing() -> Generator[None, None, None]:
@@ -598,10 +617,12 @@ class WorkerPresenceHandler(BasePresenceHandler):
         # If this is a federation sender, notify about presence updates.
         await self.maybe_send_presence_to_interested_destinations(state_to_notify)
 
-    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+    def get_currently_syncing_users_for_replication(
+        self,
+    ) -> Iterable[Tuple[str, Optional[str]]]:
         return [
-            user_id
-            for user_id, count in self._user_to_num_current_syncs.items()
+            user_id_device_id
+            for user_id_device_id, count in self._user_device_to_num_current_syncs.items()
             if count > 0
         ]
 
@@ -723,17 +744,23 @@ class PresenceHandler(BasePresenceHandler):
 
         # Keeps track of the number of *ongoing* syncs on this process. While
         # this is non zero a user will never go offline.
-        self.user_to_num_current_syncs: Dict[str, int] = {}
+        self._user_device_to_num_current_syncs: Dict[
+            Tuple[str, Optional[str]], int
+        ] = {}
 
         # Keeps track of the number of *ongoing* syncs on other processes.
+        #
         # While any sync is ongoing on another process the user will never
         # go offline.
+        #
         # Each process has a unique identifier and an update frequency. If
         # no update is received from that process within the update period then
         # we assume that all the sync requests on that process have stopped.
-        # Stored as a dict from process_id to set of user_id, and a dict of
-        # process_id to millisecond timestamp last updated.
-        self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
+        # Stored as a dict from process_id to set of (user_id, device_id), and
+        # a dict of process_id to millisecond timestamp last updated.
+        self.external_process_to_current_syncs: Dict[
+            str, Set[Tuple[str, Optional[str]]]
+        ] = {}
         self.external_process_last_updated_ms: Dict[str, int] = {}
 
         self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
@@ -938,7 +965,10 @@ class PresenceHandler(BasePresenceHandler):
             # that were syncing on that process to see if they need to be timed
             # out.
             users_to_check.update(
-                self.external_process_to_current_syncs.pop(process_id, ())
+                user_id
+                for user_id, device_id in self.external_process_to_current_syncs.pop(
+                    process_id, ()
+                )
             )
             self.external_process_last_updated_ms.pop(process_id)
 
@@ -951,11 +981,15 @@ class PresenceHandler(BasePresenceHandler):
 
         syncing_user_ids = {
             user_id
-            for user_id, count in self.user_to_num_current_syncs.items()
+            for (user_id, _), count in self._user_device_to_num_current_syncs.items()
             if count
         }
-        for user_ids in self.external_process_to_current_syncs.values():
-            syncing_user_ids.update(user_ids)
+        syncing_user_ids.update(
+            user_id
+            for user_id, _ in itertools.chain(
+                *self.external_process_to_current_syncs.values()
+            )
+        )
 
         changes = handle_timeouts(
             states,
@@ -1013,8 +1047,8 @@ class PresenceHandler(BasePresenceHandler):
         if not affect_presence or not self._presence_enabled:
             return _NullContextManager()
 
-        curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
-        self.user_to_num_current_syncs[user_id] = curr_sync + 1
+        curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0)
+        self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1
 
         # Note that this causes last_active_ts to be incremented which is not
         # what the spec wants.
@@ -1027,7 +1061,7 @@ class PresenceHandler(BasePresenceHandler):
 
         async def _end() -> None:
             try:
-                self.user_to_num_current_syncs[user_id] -= 1
+                self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1
 
                 prev_state = await self.current_state_for_user(user_id)
                 await self._update_states(
@@ -1049,12 +1083,19 @@ class PresenceHandler(BasePresenceHandler):
 
         return _user_syncing()
 
-    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+    def get_currently_syncing_users_for_replication(
+        self,
+    ) -> Iterable[Tuple[str, Optional[str]]]:
         # since we are the process handling presence, there is nothing to do here.
         return []
 
     async def update_external_syncs_row(
-        self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
+        self,
+        process_id: str,
+        user_id: str,
+        device_id: Optional[str],
+        is_syncing: bool,
+        sync_time_msec: int,
     ) -> None:
         """Update the syncing users for an external process as a delta.
 
@@ -1063,6 +1104,7 @@ class PresenceHandler(BasePresenceHandler):
                 syncing against. This allows synapse to process updates
                 as user start and stop syncing against a given process.
             user_id: The user who has started or stopped syncing
+            device_id: The user's device that has started or stopped syncing
             is_syncing: Whether or not the user is now syncing
             sync_time_msec: Time in ms when the user was last syncing
         """
@@ -1073,26 +1115,27 @@ class PresenceHandler(BasePresenceHandler):
                 process_id, set()
             )
 
-            # USER_SYNC is sent when a user starts or stops syncing on a remote
-            # process. (But only for the initial and last device.)
+            # USER_SYNC is sent when a user's device starts or stops syncing on
+            # a remote # process. (But only for the initial and last sync for that
+            # device.)
             #
-            # When a user *starts* syncing it also calls set_state(...) which
+            # When a device *starts* syncing it also calls set_state(...) which
             # will update the state, last_active_ts, and last_user_sync_ts.
-            # Simply ensure the user is tracked as syncing in this case.
+            # Simply ensure the user & device is tracked as syncing in this case.
             #
-            # When a user *stops* syncing, update the last_user_sync_ts and mark
+            # When a device *stops* syncing, update the last_user_sync_ts and mark
             # them as no longer syncing. Note this doesn't quite match the
             # monolith behaviour, which updates last_user_sync_ts at the end of
             # every sync, not just the last in-flight sync.
-            if is_syncing and user_id not in process_presence:
-                process_presence.add(user_id)
-            elif not is_syncing and user_id in process_presence:
+            if is_syncing and (user_id, device_id) not in process_presence:
+                process_presence.add((user_id, device_id))
+            elif not is_syncing and (user_id, device_id) in process_presence:
                 new_state = prev_state.copy_and_replace(
                     last_user_sync_ts=sync_time_msec
                 )
                 await self._update_states([new_state])
 
-                process_presence.discard(user_id)
+                process_presence.discard((user_id, device_id))
 
             self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
 
@@ -1106,7 +1149,9 @@ class PresenceHandler(BasePresenceHandler):
             process_presence = self.external_process_to_current_syncs.pop(
                 process_id, set()
             )
-            prev_states = await self.current_state_for_users(process_presence)
+            prev_states = await self.current_state_for_users(
+                {user_id for user_id, device_id in process_presence}
+            )
             time_now_ms = self.clock.time_msec()
 
             await self._update_states(
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 58a871c6d9..e616b5e1c8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -267,27 +267,38 @@ class UserSyncCommand(Command):
     NAME = "USER_SYNC"
 
     def __init__(
-        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+        self,
+        instance_id: str,
+        user_id: str,
+        device_id: Optional[str],
+        is_syncing: bool,
+        last_sync_ms: int,
     ):
         self.instance_id = instance_id
         self.user_id = user_id
+        self.device_id = device_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
     def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
-        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
+        device_id: Optional[str]
+        instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4)
+
+        if device_id == "None":
+            device_id = None
 
         if state not in ("start", "end"):
             raise Exception("Invalid USER_SYNC state %r" % (state,))
 
-        return cls(instance_id, user_id, state == "start", int(last_sync_ms))
+        return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms))
 
     def to_line(self) -> str:
         return " ".join(
             (
                 self.instance_id,
                 self.user_id,
+                str(self.device_id),
                 "start" if self.is_syncing else "end",
                 str(self.last_sync_ms),
             )
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 92c5a55acc..d9045d7b73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -428,7 +428,11 @@ class ReplicationCommandHandler:
 
         if self._is_presence_writer:
             return self._presence_handler.update_external_syncs_row(
-                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+                cmd.instance_id,
+                cmd.user_id,
+                cmd.device_id,
+                cmd.is_syncing,
+                cmd.last_sync_ms,
             )
         else:
             return None
@@ -699,9 +703,9 @@ class ReplicationCommandHandler:
         )
 
         now = self._clock.time_msec()
-        for user_id in currently_syncing:
+        for user_id, device_id in currently_syncing:
             connection.send_command(
-                UserSyncCommand(self._instance_id, user_id, True, now)
+                UserSyncCommand(self._instance_id, user_id, device_id, True, now)
             )
 
     def lost_connection(self, connection: IReplicationConnection) -> None:
@@ -753,11 +757,16 @@ class ReplicationCommandHandler:
         self.send_command(FederationAckCommand(self._instance_name, token))
 
     def send_user_sync(
-        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+        self,
+        instance_id: str,
+        user_id: str,
+        device_id: Optional[str],
+        is_syncing: bool,
+        last_sync_ms: int,
     ) -> None:
         """Poke the master that a user has started/stopped syncing."""
         self.send_command(
-            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+            UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms)
         )
 
     def send_user_ip(