summary refs log tree commit diff
path: root/synapse/handlers/device.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/device.py')
-rw-r--r--synapse/handlers/device.py39
1 files changed, 21 insertions, 18 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index c05a170c55..901e2310b7 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -45,13 +45,13 @@ from synapse.types import (
     JsonDict,
     StreamKeyType,
     StreamToken,
-    UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.cancellation import cancellable
 from synapse.util.metrics import measure_func
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -74,6 +74,7 @@ class DeviceWorkerHandler:
         self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
         self.server_name = hs.hostname
+        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -118,11 +119,12 @@ class DeviceWorkerHandler:
         ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
-        set_tag("device", device)
-        set_tag("ips", ips)
+        set_tag("device", str(device))
+        set_tag("ips", str(ips))
 
         return device
 
+    @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
     ) -> Collection[str]:
@@ -162,6 +164,7 @@ class DeviceWorkerHandler:
 
     @trace
     @measure_func("device.get_user_ids_changed")
+    @cancellable
     async def get_user_ids_changed(
         self, user_id: str, from_token: StreamToken
     ) -> JsonDict:
@@ -170,7 +173,7 @@ class DeviceWorkerHandler:
         """
 
         set_tag("user_id", user_id)
-        set_tag("from_token", from_token)
+        set_tag("from_token", str(from_token))
         now_room_key = self.store.get_room_max_token()
 
         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -309,6 +312,7 @@ class DeviceHandler(DeviceWorkerHandler):
         super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -319,8 +323,6 @@ class DeviceHandler(DeviceWorkerHandler):
             self.device_list_updater.incoming_device_list_update,
         )
 
-        hs.get_distributor().observe("user_left_room", self.user_left_room)
-
         # Whether `_handle_new_device_update_async` is currently processing.
         self._handle_new_device_update_is_processing = False
 
@@ -564,14 +566,6 @@ class DeviceHandler(DeviceWorkerHandler):
             StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
         )
 
-    async def user_left_room(self, user: UserID, room_id: str) -> None:
-        user_id = user.to_string()
-        room_ids = await self.store.get_rooms_for_user(user_id)
-        if not room_ids:
-            # We no longer share rooms with this user, so we'll no longer
-            # receive device updates. Mark this in DB.
-            await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
     async def store_dehydrated_device(
         self,
         user_id: str,
@@ -693,8 +687,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
                     # Ignore any users that aren't ours
                     if self.hs.is_mine_id(user_id):
-                        joined_user_ids = await self.store.get_users_in_room(room_id)
-                        hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                        hosts = set(
+                            await self._storage_controllers.state.get_current_hosts_in_room(
+                                room_id
+                            )
+                        )
                         hosts.discard(self.server_name)
 
                     # Check if we've already sent this update to some hosts
@@ -747,7 +744,13 @@ def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
-    device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
+    device.update(
+        {
+            "last_seen_user_agent": ip.get("user_agent"),
+            "last_seen_ts": ip.get("last_seen"),
+            "last_seen_ip": ip.get("ip"),
+        }
+    )
 
 
 class DeviceListUpdater:
@@ -795,7 +798,7 @@ class DeviceListUpdater:
         """
 
         set_tag("origin", origin)
-        set_tag("edu_content", edu_content)
+        set_tag("edu_content", str(edu_content))
         user_id = edu_content.pop("user_id")
         device_id = edu_content.pop("device_id")
         stream_id = str(edu_content.pop("stream_id"))  # They may come as ints