summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12213.bugfix1
-rw-r--r--synapse/handlers/events.py6
-rw-r--r--synapse/handlers/presence.py56
-rw-r--r--synapse/rest/client/sync.py9
-rw-r--r--tests/handlers/test_presence.py79
5 files changed, 133 insertions, 18 deletions
diff --git a/changelog.d/12213.bugfix b/changelog.d/12213.bugfix
new file mode 100644
index 0000000000..9278e3a9c1
--- /dev/null
+++ b/changelog.d/12213.bugfix
@@ -0,0 +1 @@
+Prevent a sync request from removing a user's busy presence status.
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d2ccb5c5d3..e89c4df314 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -16,7 +16,7 @@ import logging
 import random
 from typing import TYPE_CHECKING, Iterable, List, Optional
 
-from synapse.api.constants import EduTypes, EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig
@@ -67,7 +67,9 @@ class EventStreamHandler:
         presence_handler = self.hs.get_presence_handler()
 
         context = await presence_handler.user_syncing(
-            auth_user_id, affect_presence=affect_presence
+            auth_user_id,
+            affect_presence=affect_presence,
+            presence_state=PresenceState.ONLINE,
         )
         with context:
             if timeout:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 209a4b0e52..d078162c29 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -151,7 +151,7 @@ class BasePresenceHandler(abc.ABC):
 
     @abc.abstractmethod
     async def user_syncing(
-        self, user_id: str, affect_presence: bool
+        self, user_id: str, affect_presence: bool, presence_state: str
     ) -> ContextManager[None]:
         """Returns a context manager that should surround any stream requests
         from the user.
@@ -165,6 +165,7 @@ class BasePresenceHandler(abc.ABC):
             affect_presence: If false this function will be a no-op.
                 Useful for streams that are not associated with an actual
                 client that is being used by a user.
+            presence_state: The presence state indicated in the sync request
         """
 
     @abc.abstractmethod
@@ -228,6 +229,11 @@ class BasePresenceHandler(abc.ABC):
 
         return states
 
+    async def current_state_for_user(self, user_id: str) -> UserPresenceState:
+        """Get the current presence state for a user."""
+        res = await self.current_state_for_users([user_id])
+        return res[user_id]
+
     @abc.abstractmethod
     async def set_state(
         self,
@@ -461,7 +467,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
                 self.send_user_sync(user_id, False, last_sync_ms)
 
     async def user_syncing(
-        self, user_id: str, affect_presence: bool
+        self, user_id: str, affect_presence: bool, presence_state: str
     ) -> ContextManager[None]:
         """Record that a user is syncing.
 
@@ -471,6 +477,17 @@ class WorkerPresenceHandler(BasePresenceHandler):
         if not affect_presence or not self._presence_enabled:
             return _NullContextManager()
 
+        prev_state = await self.current_state_for_user(user_id)
+        if prev_state != PresenceState.BUSY:
+            # We set state here but pass ignore_status_msg = True as we don't want to
+            # cause the status message to be cleared.
+            # Note that this causes last_active_ts to be incremented which is not
+            # what the spec wants: see comment in the BasePresenceHandler version
+            # of this function.
+            await self.set_state(
+                UserID.from_string(user_id), {"presence": presence_state}, True
+            )
+
         curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
         self._user_to_num_current_syncs[user_id] = curr_sync + 1
 
@@ -942,7 +959,10 @@ class PresenceHandler(BasePresenceHandler):
         await self._update_states([prev_state.copy_and_replace(**new_fields)])
 
     async def user_syncing(
-        self, user_id: str, affect_presence: bool = True
+        self,
+        user_id: str,
+        affect_presence: bool = True,
+        presence_state: str = PresenceState.ONLINE,
     ) -> ContextManager[None]:
         """Returns a context manager that should surround any stream requests
         from the user.
@@ -956,6 +976,7 @@ class PresenceHandler(BasePresenceHandler):
             affect_presence: If false this function will be a no-op.
                 Useful for streams that are not associated with an actual
                 client that is being used by a user.
+            presence_state: The presence state indicated in the sync request
         """
         # Override if it should affect the user's presence, if presence is
         # disabled.
@@ -967,9 +988,25 @@ class PresenceHandler(BasePresenceHandler):
             self.user_to_num_current_syncs[user_id] = curr_sync + 1
 
             prev_state = await self.current_state_for_user(user_id)
+
+            # If they're busy then they don't stop being busy just by syncing,
+            # so just update the last sync time.
+            if prev_state.state != PresenceState.BUSY:
+                # XXX: We set_state separately here and just update the last_active_ts above
+                # This keeps the logic as similar as possible between the worker and single
+                # process modes. Using set_state will actually cause last_active_ts to be
+                # updated always, which is not what the spec calls for, but synapse has done
+                # this for... forever, I think.
+                await self.set_state(
+                    UserID.from_string(user_id), {"presence": presence_state}, True
+                )
+                # Retrieve the new state for the logic below. This should come from the
+                # in-memory cache.
+                prev_state = await self.current_state_for_user(user_id)
+
+            # To keep the single process behaviour consistent with worker mode, run the
+            # same logic as `update_external_syncs_row`, even though it looks weird.
             if prev_state.state == PresenceState.OFFLINE:
-                # If they're currently offline then bring them online, otherwise
-                # just update the last sync times.
                 await self._update_states(
                     [
                         prev_state.copy_and_replace(
@@ -979,6 +1016,10 @@ class PresenceHandler(BasePresenceHandler):
                         )
                     ]
                 )
+            # otherwise, set the new presence state & update the last sync time,
+            # but don't update last_active_ts as this isn't an indication that
+            # they've been active (even though it's probably been updated by
+            # set_state above)
             else:
                 await self._update_states(
                     [
@@ -1086,11 +1127,6 @@ class PresenceHandler(BasePresenceHandler):
             )
             self.external_process_last_updated_ms.pop(process_id, None)
 
-    async def current_state_for_user(self, user_id: str) -> UserPresenceState:
-        """Get the current presence state for a user."""
-        res = await self.current_state_for_users([user_id])
-        return res[user_id]
-
     async def _persist_and_notify(self, states: List[UserPresenceState]) -> None:
         """Persist states in the database, poke the notifier and send to
         interested remote servers
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 2e25e8638b..e8772f86e7 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -180,13 +180,10 @@ class SyncRestServlet(RestServlet):
 
         affect_presence = set_presence != PresenceState.OFFLINE
 
-        if affect_presence:
-            await self.presence_handler.set_state(
-                user, {"presence": set_presence}, True
-            )
-
         context = await self.presence_handler.user_syncing(
-            user.to_string(), affect_presence=affect_presence
+            user.to_string(),
+            affect_presence=affect_presence,
+            presence_state=set_presence,
         )
         with context:
             sync_result = await self.sync_handler.wait_for_sync_for_user(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index b2ed9cbe37..c96dc6caf2 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -657,6 +657,85 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
         # Mark user as online and `status_msg = None`
         self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
 
+    def test_set_presence_from_syncing_not_set(self):
+        """Test that presence is not set by syncing if affect_presence is false"""
+        user_id = "@test:server"
+        status_msg = "I'm here!"
+
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.UNAVAILABLE, status_msg
+        )
+
+        self.get_success(
+            self.presence_handler.user_syncing(user_id, False, PresenceState.ONLINE)
+        )
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        # we should still be unavailable
+        self.assertEqual(state.state, PresenceState.UNAVAILABLE)
+        # and status message should still be the same
+        self.assertEqual(state.status_msg, status_msg)
+
+    def test_set_presence_from_syncing_is_set(self):
+        """Test that presence is set by syncing if affect_presence is true"""
+        user_id = "@test:server"
+        status_msg = "I'm here!"
+
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.UNAVAILABLE, status_msg
+        )
+
+        self.get_success(
+            self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+        )
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        # we should now be online
+        self.assertEqual(state.state, PresenceState.ONLINE)
+
+    def test_set_presence_from_syncing_keeps_status(self):
+        """Test that presence set by syncing retains status message"""
+        user_id = "@test:server"
+        status_msg = "I'm here!"
+
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.UNAVAILABLE, status_msg
+        )
+
+        self.get_success(
+            self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+        )
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        # our status message should be the same as it was before
+        self.assertEqual(state.status_msg, status_msg)
+
+    def test_set_presence_from_syncing_keeps_busy(self):
+        """Test that presence set by syncing doesn't affect busy status"""
+        # while this isn't the default
+        self.presence_handler._busy_presence_enabled = True
+
+        user_id = "@test:server"
+        status_msg = "I'm busy!"
+
+        self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg)
+
+        self.get_success(
+            self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+        )
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        # we should still be busy
+        self.assertEqual(state.state, PresenceState.BUSY)
+
     def _set_presencestate_with_status_msg(
         self, user_id: str, state: str, status_msg: Optional[str]
     ):