diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index c05a170c55..901e2310b7 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -45,13 +45,13 @@ from synapse.types import (
JsonDict,
StreamKeyType,
StreamToken,
- UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.cancellation import cancellable
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination
@@ -74,6 +74,7 @@ class DeviceWorkerHandler:
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
+ self._msc3852_enabled = hs.config.experimental.msc3852_enabled
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -118,11 +119,12 @@ class DeviceWorkerHandler:
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
- set_tag("device", device)
- set_tag("ips", ips)
+ set_tag("device", str(device))
+ set_tag("ips", str(ips))
return device
+ @cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]:
@@ -162,6 +164,7 @@ class DeviceWorkerHandler:
@trace
@measure_func("device.get_user_ids_changed")
+ @cancellable
async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
) -> JsonDict:
@@ -170,7 +173,7 @@ class DeviceWorkerHandler:
"""
set_tag("user_id", user_id)
- set_tag("from_token", from_token)
+ set_tag("from_token", str(from_token))
now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -309,6 +312,7 @@ class DeviceHandler(DeviceWorkerHandler):
super().__init__(hs)
self.federation_sender = hs.get_federation_sender()
+ self._storage_controllers = hs.get_storage_controllers()
self.device_list_updater = DeviceListUpdater(hs, self)
@@ -319,8 +323,6 @@ class DeviceHandler(DeviceWorkerHandler):
self.device_list_updater.incoming_device_list_update,
)
- hs.get_distributor().observe("user_left_room", self.user_left_room)
-
# Whether `_handle_new_device_update_async` is currently processing.
self._handle_new_device_update_is_processing = False
@@ -564,14 +566,6 @@ class DeviceHandler(DeviceWorkerHandler):
StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
)
- async def user_left_room(self, user: UserID, room_id: str) -> None:
- user_id = user.to_string()
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- # We no longer share rooms with this user, so we'll no longer
- # receive device updates. Mark this in DB.
- await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
async def store_dehydrated_device(
self,
user_id: str,
@@ -693,8 +687,11 @@ class DeviceHandler(DeviceWorkerHandler):
# Ignore any users that aren't ours
if self.hs.is_mine_id(user_id):
- joined_user_ids = await self.store.get_users_in_room(room_id)
- hosts = {get_domain_from_id(u) for u in joined_user_ids}
+ hosts = set(
+ await self._storage_controllers.state.get_current_hosts_in_room(
+ room_id
+ )
+ )
hosts.discard(self.server_name)
# Check if we've already sent this update to some hosts
@@ -747,7 +744,13 @@ def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
+ device.update(
+ {
+ "last_seen_user_agent": ip.get("user_agent"),
+ "last_seen_ts": ip.get("last_seen"),
+ "last_seen_ip": ip.get("ip"),
+ }
+ )
class DeviceListUpdater:
@@ -795,7 +798,7 @@ class DeviceListUpdater:
"""
set_tag("origin", origin)
- set_tag("edu_content", edu_content)
+ set_tag("edu_content", str(edu_content))
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
|