summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-08-28 13:08:49 -0400
committerGitHub <noreply@github.com>2023-08-28 13:08:49 -0400
commit40901af5e096cb10ab69141875b071b4ea4ed1e0 (patch)
treeddc310741a9c98bda2435b848832f9e3be49bf35
parentCombine logic about not overriding BUSY presence. (#16170) (diff)
downloadsynapse-40901af5e096cb10ab69141875b071b4ea4ed1e0.tar.xz
Pass the device ID around in the presence handler (#16171)
Refactoring to pass the device ID (in addition to the user ID) through
the presence handler (specifically the `user_syncing`, `set_state`,
and `bump_presence_active_time` methods and their replication
versions).
-rw-r--r--changelog.d/16171.misc1
-rw-r--r--synapse/handlers/events.py1
-rw-r--r--synapse/handlers/message.py9
-rw-r--r--synapse/handlers/presence.py46
-rw-r--r--synapse/replication/http/presence.py11
-rw-r--r--synapse/rest/client/presence.py2
-rw-r--r--synapse/rest/client/read_marker.py4
-rw-r--r--synapse/rest/client/receipts.py4
-rw-r--r--synapse/rest/client/room.py4
-rw-r--r--synapse/rest/client/sync.py1
-rw-r--r--tests/handlers/test_presence.py38
11 files changed, 91 insertions, 30 deletions
diff --git a/changelog.d/16171.misc b/changelog.d/16171.misc
new file mode 100644
index 0000000000..4d709cb56e
--- /dev/null
+++ b/changelog.d/16171.misc
@@ -0,0 +1 @@
+Track per-device information in the presence code.
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 33359f6ed7..d12803bf0f 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -67,6 +67,7 @@ class EventStreamHandler:
 
         context = await presence_handler.user_syncing(
             requester.user.to_string(),
+            requester.device_id,
             affect_presence=affect_presence,
             presence_state=PresenceState.ONLINE,
         )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 3184bfb047..4a15c76a7b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1921,7 +1921,10 @@ class EventCreationHandler:
                 # We don't want to block sending messages on any presence code. This
                 # matters as sometimes presence code can take a while.
                 run_as_background_process(
-                    "bump_presence_active_time", self._bump_active_time, requester.user
+                    "bump_presence_active_time",
+                    self._bump_active_time,
+                    requester.user,
+                    requester.device_id,
                 )
 
         async def _notify() -> None:
@@ -1958,10 +1961,10 @@ class EventCreationHandler:
         logger.info("maybe_kick_guest_users %r", current_state)
         await self.hs.get_room_member_handler().kick_guest_users(current_state)
 
-    async def _bump_active_time(self, user: UserID) -> None:
+    async def _bump_active_time(self, user: UserID, device_id: Optional[str]) -> None:
         try:
             presence = self.hs.get_presence_handler()
-            await presence.bump_presence_active_time(user)
+            await presence.bump_presence_active_time(user, device_id)
         except Exception:
             logger.exception("Error bumping presence active time")
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index c395dcdb43..50c68c86ce 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -165,7 +165,11 @@ class BasePresenceHandler(abc.ABC):
 
     @abc.abstractmethod
     async def user_syncing(
-        self, user_id: str, affect_presence: bool, presence_state: str
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        affect_presence: bool,
+        presence_state: str,
     ) -> ContextManager[None]:
         """Returns a context manager that should surround any stream requests
         from the user.
@@ -176,6 +180,7 @@ class BasePresenceHandler(abc.ABC):
 
         Args:
             user_id: the user that is starting a sync
+            device_id: the user's device that is starting a sync
             affect_presence: If false this function will be a no-op.
                 Useful for streams that are not associated with an actual
                 client that is being used by a user.
@@ -252,6 +257,7 @@ class BasePresenceHandler(abc.ABC):
     async def set_state(
         self,
         target_user: UserID,
+        device_id: Optional[str],
         state: JsonDict,
         force_notify: bool = False,
         is_sync: bool = False,
@@ -260,6 +266,7 @@ class BasePresenceHandler(abc.ABC):
 
         Args:
             target_user: The ID of the user to set the presence state of.
+            device_id: the device that the user is setting the presence state of.
             state: The presence state as a JSON dictionary.
             force_notify: Whether to force notification of the update to clients.
             is_sync: True if this update was from a sync, which results in
@@ -269,7 +276,9 @@ class BasePresenceHandler(abc.ABC):
         """
 
     @abc.abstractmethod
-    async def bump_presence_active_time(self, user: UserID) -> None:
+    async def bump_presence_active_time(
+        self, user: UserID, device_id: Optional[str]
+    ) -> None:
         """We've seen the user do something that indicates they're interacting
         with the app.
         """
@@ -381,7 +390,9 @@ class BasePresenceHandler(abc.ABC):
         # We set force_notify=True here so that this presence update is guaranteed to
         # increment the presence stream ID (which resending the current user's presence
         # otherwise would not do).
-        await self.set_state(UserID.from_string(user_id), state, force_notify=True)
+        await self.set_state(
+            UserID.from_string(user_id), None, state, force_notify=True
+        )
 
     async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
         raise NotImplementedError(
@@ -481,7 +492,11 @@ class WorkerPresenceHandler(BasePresenceHandler):
                 self.send_user_sync(user_id, False, last_sync_ms)
 
     async def user_syncing(
-        self, user_id: str, affect_presence: bool, presence_state: str
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        affect_presence: bool,
+        presence_state: str,
     ) -> ContextManager[None]:
         """Record that a user is syncing.
 
@@ -495,6 +510,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
         # what the spec wants.
         await self.set_state(
             UserID.from_string(user_id),
+            device_id,
             state={"presence": presence_state},
             is_sync=True,
         )
@@ -592,6 +608,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
     async def set_state(
         self,
         target_user: UserID,
+        device_id: Optional[str],
         state: JsonDict,
         force_notify: bool = False,
         is_sync: bool = False,
@@ -600,6 +617,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 
         Args:
             target_user: The ID of the user to set the presence state of.
+            device_id: the device that the user is setting the presence state of.
             state: The presence state as a JSON dictionary.
             force_notify: Whether to force notification of the update to clients.
             is_sync: True if this update was from a sync, which results in
@@ -622,12 +640,15 @@ class WorkerPresenceHandler(BasePresenceHandler):
         await self._set_state_client(
             instance_name=self._presence_writer_instance,
             user_id=user_id,
+            device_id=device_id,
             state=state,
             force_notify=force_notify,
             is_sync=is_sync,
         )
 
-    async def bump_presence_active_time(self, user: UserID) -> None:
+    async def bump_presence_active_time(
+        self, user: UserID, device_id: Optional[str]
+    ) -> None:
         """We've seen the user do something that indicates they're interacting
         with the app.
         """
@@ -638,7 +659,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
         # Proxy request to instance that writes presence
         user_id = user.to_string()
         await self._bump_active_client(
-            instance_name=self._presence_writer_instance, user_id=user_id
+            instance_name=self._presence_writer_instance,
+            user_id=user_id,
+            device_id=device_id,
         )
 
 
@@ -943,7 +966,9 @@ class PresenceHandler(BasePresenceHandler):
 
         return await self._update_states(changes)
 
-    async def bump_presence_active_time(self, user: UserID) -> None:
+    async def bump_presence_active_time(
+        self, user: UserID, device_id: Optional[str]
+    ) -> None:
         """We've seen the user do something that indicates they're interacting
         with the app.
         """
@@ -966,6 +991,7 @@ class PresenceHandler(BasePresenceHandler):
     async def user_syncing(
         self,
         user_id: str,
+        device_id: Optional[str],
         affect_presence: bool = True,
         presence_state: str = PresenceState.ONLINE,
     ) -> ContextManager[None]:
@@ -977,7 +1003,8 @@ class PresenceHandler(BasePresenceHandler):
         when users disconnect/reconnect.
 
         Args:
-            user_id
+            user_id: the user that is starting a sync
+            device_id: the user's device that is starting a sync
             affect_presence: If false this function will be a no-op.
                 Useful for streams that are not associated with an actual
                 client that is being used by a user.
@@ -993,6 +1020,7 @@ class PresenceHandler(BasePresenceHandler):
         # what the spec wants.
         await self.set_state(
             UserID.from_string(user_id),
+            device_id,
             state={"presence": presence_state},
             is_sync=True,
         )
@@ -1163,6 +1191,7 @@ class PresenceHandler(BasePresenceHandler):
     async def set_state(
         self,
         target_user: UserID,
+        device_id: Optional[str],
         state: JsonDict,
         force_notify: bool = False,
         is_sync: bool = False,
@@ -1171,6 +1200,7 @@ class PresenceHandler(BasePresenceHandler):
 
         Args:
             target_user: The ID of the user to set the presence state of.
+            device_id: the device that the user is setting the presence state of.
             state: The presence state as a JSON dictionary.
             force_notify: Whether to force notification of the update to clients.
             is_sync: True if this update was from a sync, which results in
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index a24fb9310b..6c9e79fb07 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from twisted.web.server import Request
 
@@ -51,14 +51,14 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
         self._presence_handler = hs.get_presence_handler()
 
     @staticmethod
-    async def _serialize_payload(user_id: str) -> JsonDict:  # type: ignore[override]
-        return {}
+    async def _serialize_payload(user_id: str, device_id: Optional[str]) -> JsonDict:  # type: ignore[override]
+        return {"device_id": device_id}
 
     async def _handle_request(  # type: ignore[override]
         self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, JsonDict]:
         await self._presence_handler.bump_presence_active_time(
-            UserID.from_string(user_id)
+            UserID.from_string(user_id), content.get("device_id")
         )
 
         return (200, {})
@@ -95,11 +95,13 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
     @staticmethod
     async def _serialize_payload(  # type: ignore[override]
         user_id: str,
+        device_id: Optional[str],
         state: JsonDict,
         force_notify: bool = False,
         is_sync: bool = False,
     ) -> JsonDict:
         return {
+            "device_id": device_id,
             "state": state,
             "force_notify": force_notify,
             "is_sync": is_sync,
@@ -110,6 +112,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
     ) -> Tuple[int, JsonDict]:
         await self._presence_handler.set_state(
             UserID.from_string(user_id),
+            content.get("device_id"),
             content["state"],
             content["force_notify"],
             content.get("is_sync", False),
diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py
index 8e193330f8..d578faa969 100644
--- a/synapse/rest/client/presence.py
+++ b/synapse/rest/client/presence.py
@@ -97,7 +97,7 @@ class PresenceStatusRestServlet(RestServlet):
             raise SynapseError(400, "Unable to parse state")
 
         if self._use_presence:
-            await self.presence_handler.set_state(user, state)
+            await self.presence_handler.set_state(user, requester.device_id, state)
 
         return 200, {}
 
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 4f96e51eeb..1707e51972 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -52,7 +52,9 @@ class ReadMarkerRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        await self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(
+            requester.user, requester.device_id
+        )
 
         body = parse_json_object_from_request(request)
 
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 316e7b9982..869a374459 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -94,7 +94,9 @@ class ReceiptRestServlet(RestServlet):
                     Codes.INVALID_PARAM,
                 )
 
-        await self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(
+            requester.user, requester.device_id
+        )
 
         if receipt_type == ReceiptTypes.FULLY_READ:
             await self.read_marker_handler.received_client_read_marker(
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index dc498001e4..553938ce9d 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -1229,7 +1229,9 @@ class RoomTypingRestServlet(RestServlet):
 
         content = parse_json_object_from_request(request)
 
-        await self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(
+            requester.user, requester.device_id
+        )
 
         # Limit timeout to stop people from setting silly typing timeouts.
         timeout = min(content.get("timeout", 30000), 120000)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d7854ed4fd..42bdd3bb10 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -205,6 +205,7 @@ class SyncRestServlet(RestServlet):
 
         context = await self.presence_handler.user_syncing(
             user.to_string(),
+            requester.device_id,
             affect_presence=affect_presence,
             presence_state=set_presence,
         )
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index a3fdcf7f93..a987267308 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -524,6 +524,7 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = f"@test:{self.hs.config.server.server_name}"
+        self.device_id = "dev-1"
 
         # Move the reactor to the initial time.
         self.reactor.advance(1000)
@@ -608,7 +609,10 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2)
         self.get_success(
             presence_handler.user_syncing(
-                self.user_id, sync_state != PresenceState.OFFLINE, sync_state
+                self.user_id,
+                self.device_id,
+                sync_state != PresenceState.OFFLINE,
+                sync_state,
             )
         )
 
@@ -632,6 +636,7 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
 class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
     user_id = "@test:server"
     user_id_obj = UserID.from_string(user_id)
+    device_id = "dev-1"
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.presence_handler = hs.get_presence_handler()
@@ -652,7 +657,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
 
         self.get_success(
             worker_presence_handler.user_syncing(
-                self.user_id, True, PresenceState.ONLINE
+                self.user_id, self.device_id, True, PresenceState.ONLINE
             ),
             by=0.1,
         )
@@ -708,7 +713,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # Mark user as offline
         self.get_success(
             self.presence_handler.set_state(
-                self.user_id_obj, {"presence": PresenceState.OFFLINE}
+                self.user_id_obj, self.device_id, {"presence": PresenceState.OFFLINE}
             )
         )
 
@@ -740,7 +745,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # Mark user as online again
         self.get_success(
             self.presence_handler.set_state(
-                self.user_id_obj, {"presence": PresenceState.ONLINE}
+                self.user_id_obj, self.device_id, {"presence": PresenceState.ONLINE}
             )
         )
 
@@ -769,7 +774,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
 
         self.get_success(
             self.presence_handler.user_syncing(
-                self.user_id, False, PresenceState.ONLINE
+                self.user_id, self.device_id, False, PresenceState.ONLINE
             )
         )
 
@@ -786,7 +791,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
 
         self.get_success(
-            self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            )
         )
 
         state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
@@ -800,7 +807,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
 
         self.get_success(
-            self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            )
         )
 
         state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
@@ -838,7 +847,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # /presence/*.
         self.get_success(
             worker_to_sync_against.get_presence_handler().user_syncing(
-                self.user_id, True, PresenceState.ONLINE
+                self.user_id, self.device_id, True, PresenceState.ONLINE
             ),
             by=0.1,
         )
@@ -875,6 +884,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self.get_success(
             self.presence_handler.set_state(
                 self.user_id_obj,
+                self.device_id,
                 {"presence": state, "status_msg": status_msg},
             )
         )
@@ -1116,7 +1126,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Mark test2 as online, test will be offline with a last_active of 0
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test2:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )
         self.reactor.pump([0])  # Wait for presence updates to be handled
@@ -1163,7 +1175,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Mark test as online
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )
 
@@ -1171,7 +1185,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Note we don't join them to the room yet
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test2:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )