diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 1a8379854c..b1e55e1b9e 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -45,13 +45,13 @@ from synapse.types import (
JsonDict,
StreamKeyType,
StreamToken,
- UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.cancellation import cancellable
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
class DeviceWorkerHandler:
+ device_list_updater: "DeviceListWorkerUpdater"
+
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
@@ -74,6 +76,9 @@ class DeviceWorkerHandler:
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
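+ # MSC3852: expose a device's last-seen user agent in the devices API.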
+ self._msc3852_enabled = hs.config.experimental.msc3852_enabled
+
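+ # On workers, device list resyncs are proxied to the main process
+ # over replication (see `DeviceListWorkerUpdater` below).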
+ self.device_list_updater = DeviceListWorkerUpdater(hs)
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -98,6 +103,19 @@ class DeviceWorkerHandler:
log_kv(device_map)
return devices
+ async def get_dehydrated_device(
+ self, user_id: str
+ ) -> Optional[Tuple[str, JsonDict]]:
+ """Retrieve the information for a dehydrated device.
+
+ Args:
+ user_id: the user whose dehydrated device we are looking for
+ Returns:
+ a tuple whose first item is the device ID, and the second item is
+ the dehydrated device information
+ """
+ return await self.store.get_dehydrated_device(user_id)
+
@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
"""Retrieve the given device
@@ -123,9 +141,10 @@ class DeviceWorkerHandler:
return device
+ @cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
- ) -> Collection[str]:
+ ) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
@@ -162,6 +181,7 @@ class DeviceWorkerHandler:
@trace
@measure_func("device.get_user_ids_changed")
+ @cancellable
async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
) -> JsonDict:
@@ -192,7 +212,9 @@ class DeviceWorkerHandler:
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
- current_state_ids = await self._state_storage.get_current_state_ids(room_id)
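+ # `await_full_state=False`: don't block waiting for a partial-state
+ # (faster join) room to be fully synced; the current, possibly
+ # partial, state is good enough here.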
+ current_state_ids = await self._state_storage.get_current_state_ids(
+ room_id, await_full_state=False
+ )
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
@@ -231,7 +253,8 @@ class DeviceWorkerHandler:
# mapping from event_id -> state_dict
prev_state_ids = await self._state_storage.get_state_ids_for_events(
- event_ids
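+ # As above, don't block waiting for full state in partial-state rooms.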
+ event_ids,
+ await_full_state=False,
)
# Check if we've joined the room? If so we just blindly add all the users to
@@ -267,11 +290,9 @@ class DeviceWorkerHandler:
possibly_left = possibly_changed | possibly_left
# Double check if we still share rooms with the given user.
- users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
- possibly_left
- )
+ users_rooms = await self.store.get_rooms_for_users(possibly_left)
for changed_user_id, entries in users_rooms.items():
- if any(e.room_id in room_ids for e in entries):
+ if any(rid in room_ids for rid in entries):
possibly_left.discard(changed_user_id)
else:
possibly_joined.discard(changed_user_id)
@@ -303,12 +324,26 @@ class DeviceWorkerHandler:
"self_signing_key": self_signing_key,
}
+ async def handle_room_un_partial_stated(self, room_id: str) -> None:
+ """Handles sending appropriate device list updates in a room that has
+ gone from partial to full state.
+ """
+
+ # TODO(faster_joins): worker mode support
+ # https://github.com/matrix-org/synapse/issues/12994
+ logger.error(
+ "Cannot handle device list state for partial join: not supported on workers."
+ )

+
class DeviceHandler(DeviceWorkerHandler):
+ device_list_updater: "DeviceListUpdater"
+
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.federation_sender = hs.get_federation_sender()
+ self._storage_controllers = hs.get_storage_controllers()
self.device_list_updater = DeviceListUpdater(hs, self)
@@ -319,8 +354,6 @@ class DeviceHandler(DeviceWorkerHandler):
self.device_list_updater.incoming_device_list_update,
)
- hs.get_distributor().observe("user_left_room", self.user_left_room)
-
# Whether `_handle_new_device_update_async` is currently processing.
self._handle_new_device_update_is_processing = False
@@ -564,14 +597,6 @@ class DeviceHandler(DeviceWorkerHandler):
StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
)
- async def user_left_room(self, user: UserID, room_id: str) -> None:
- user_id = user.to_string()
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- # We no longer share rooms with this user, so we'll no longer
- # receive device updates. Mark this in DB.
- await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
async def store_dehydrated_device(
self,
user_id: str,
@@ -600,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, [old_device_id])
return device_id
- async def get_dehydrated_device(
- self, user_id: str
- ) -> Optional[Tuple[str, JsonDict]]:
- """Retrieve the information for a dehydrated device.
-
- Args:
- user_id: the user whose dehydrated device we are looking for
- Returns:
- a tuple whose first item is the device ID, and the second item is
- the dehydrated device information
- """
- return await self.store.get_dehydrated_device(user_id)
-
async def rehydrate_device(
self, user_id: str, access_token: str, device_id: str
) -> dict:
@@ -676,13 +688,33 @@ class DeviceHandler(DeviceWorkerHandler):
hosts_already_sent_to: Set[str] = set()
try:
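+ # Start from the position we last converted up to; we persist it
+ # below so that we can pick up where we left off after a restart.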
+ stream_id, room_id = await self.store.get_device_change_last_converted_pos()
+
while True:
self._handle_new_device_update_new_data = False
- rows = await self.store.get_uncoverted_outbound_room_pokes()
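+ # Note the current stream position *before* fetching unconverted
+ # rows, so we know how far we can safely advance if none are found.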
+ max_stream_id = self.store.get_device_stream_token()
+ rows = await self.store.get_uncoverted_outbound_room_pokes(
+ stream_id, room_id
+ )
if not rows:
# If the DB returned nothing then there is nothing left to
# do, *unless* a new device list update happened during the
# DB query.
+
+ # Advance `(stream_id, room_id)`.
+ # `max_stream_id` was taken *before* the query for unconverted
+ # rows, so any rows that remain unconverted must have a stream ID
+ # larger than `max_stream_id`, and it is safe to advance to it.
+ if max_stream_id > stream_id:
+ stream_id, room_id = max_stream_id, ""
+ await self.store.set_device_change_last_converted_pos(
+ stream_id, room_id
+ )
+ else:
+ assert max_stream_id == stream_id
+ # Avoid moving `room_id` backwards.
+
if self._handle_new_device_update_new_data:
continue
else:
@@ -693,9 +725,16 @@ class DeviceHandler(DeviceWorkerHandler):
# Ignore any users that aren't ours
if self.hs.is_mine_id(user_id):
- joined_user_ids = await self.store.get_users_in_room(room_id)
- hosts = {get_domain_from_id(u) for u in joined_user_ids}
+ hosts = set(
+ await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
+ room_id
+ )
+ )
hosts.discard(self.server_name)
+ # For rooms with partial state, `hosts` is merely an
+ # approximation. Once the room transitions to full state, we
+ # will have to send out device list updates to any servers we
+ # missed.
# Check if we've already sent this update to some hosts
if current_stream_id == stream_id:
@@ -705,7 +744,6 @@ class DeviceHandler(DeviceWorkerHandler):
user_id=user_id,
device_id=device_id,
room_id=room_id,
- stream_id=stream_id,
hosts=hosts,
context=opentracing_context,
)
@@ -739,18 +777,147 @@ class DeviceHandler(DeviceWorkerHandler):
hosts_already_sent_to.update(hosts)
current_stream_id = stream_id
+ # Advance `(stream_id, room_id)`.
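+ # `rows` is ordered by stream ID, so the last row tells us how far
+ # we have now converted.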
+ _, _, room_id, stream_id, _ = rows[-1]
+ await self.store.set_device_change_last_converted_pos(
+ stream_id, room_id
+ )
+
finally:
self._handle_new_device_update_is_processing = False
+ async def handle_room_un_partial_stated(self, room_id: str) -> None:
+ """Handles sending appropriate device list updates in a room that has
+ gone from partial to full state.
+ """
+
+ # We defer to the device list updater to handle pending remote device
+ # list updates.
+ await self.device_list_updater.handle_room_un_partial_stated(room_id)
+
+ # Replay local updates.
+ (
+ join_event_id,
+ device_lists_stream_id,
+ ) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
+ room_id
+ )
+
+ # Get the local device list changes that have happened in the room since
+ # we started joining. If there are no updates there's nothing left to do.
+ changes = await self.store.get_device_list_changes_in_room(
+ room_id, device_lists_stream_id
+ )
+ local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
+ if not local_changes:
+ return
+
+ # Note: We have persisted the full state at this point, we just haven't
+ # cleared the `partial_room` flag.
+ join_state_ids = await self._state_storage.get_state_ids_for_event(
+ join_event_id, await_full_state=False
+ )
+ current_state_ids = await self.store.get_partial_current_state_ids(room_id)
+
+ # Now we need to work out all servers that might have been in the room
+ # at any point during our join.
+
+ # First we look for any membership states that have changed between the
+ # initial join and now...
+ all_keys = set(join_state_ids)
+ all_keys.update(current_state_ids)
+
+ potentially_changed_hosts = set()
+ for etype, state_key in all_keys:
+ if etype != EventTypes.Member:
+ continue
+
+ prev = join_state_ids.get((etype, state_key))
+ current = current_state_ids.get((etype, state_key))
+
+ if prev != current:
+ potentially_changed_hosts.add(get_domain_from_id(state_key))
+
+ # ... then we add all the hosts that are currently joined to the room...
+ current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
+ potentially_changed_hosts.update(current_hosts_in_room)
+
+ # ... and finally we remove any hosts that we were told about, as we
+ # will have sent device list updates to those hosts when they happened.
+ known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
+ room_id
+ )
+ potentially_changed_hosts.difference_update(known_hosts_at_join)
+
+ potentially_changed_hosts.discard(self.server_name)
+
+ if not potentially_changed_hosts:
+ # Nothing to do.
+ return
+
+ logger.info(
+ "Found %d changed hosts to send device list updates to",
+ len(potentially_changed_hosts),
+ )
+
+ for user_id, device_id in local_changes:
+ await self.store.add_device_list_outbound_pokes(
+ user_id=user_id,
+ device_id=device_id,
+ room_id=room_id,
+ hosts=potentially_changed_hosts,
+ context=None,
+ )
+
+ # Notify things that device lists need to be sent out.
+ self.notifier.notify_replication()
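+ # `immediate=False` lets the federation sender batch device list
+ # pokes rather than waking every destination at once.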
+ for host in potentially_changed_hosts:
+ self.federation_sender.send_device_messages(host, immediate=False)
+
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
+ device.update(
+ {
+ "last_seen_user_agent": ip.get("user_agent"),
+ "last_seen_ts": ip.get("last_seen"),
+ "last_seen_ip": ip.get("ip"),
+ }
+ )
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+ "Handles incoming device list updates from federation and contacts the main process over replication"
+
+ def __init__(self, hs: "HomeServer"):
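+ # Imported here to avoid a circular import between the device
+ # handler and the replication code.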
+ from synapse.replication.http.devices import (
+ ReplicationUserDevicesResyncRestServlet,
+ )
+
+ self._user_device_resync_client = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
+
+ async def user_device_resync(
+ self, user_id: str, mark_failed_as_stale: bool = True
+ ) -> Optional[JsonDict]:
+ """Fetches all devices for a user and updates the device cache with them.
+
+ Args:
+ user_id: The ID of the user whose device list will be updated.
+ mark_failed_as_stale: Whether to mark the user's device list as stale
+ if the attempt to resync failed.
+ Returns:
+ A dict with device info, in the same format as the "devices" field
+ in the response to this request:
+ https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ """
+ return await self._user_device_resync_client(user_id=user_id)
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
@@ -826,6 +993,19 @@ class DeviceListUpdater:
)
return
+ # Check if we are partially joining any rooms. If so we need to store
+ # all device list updates so that we can handle them correctly once we
+ # know who is in the room.
+ # TODO(faster joins): this fetches and processes a bunch of data that we don't
+ # use. Could be replaced by a tighter query e.g.
+ # SELECT EXISTS(SELECT 1 FROM partial_state_rooms)
+ partial_rooms = await self.store.get_partial_state_room_resync_info()
+ if partial_rooms:
+ await self.store.add_remote_device_list_to_pending(
+ user_id,
+ device_id,
+ )
+
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
@@ -1165,3 +1345,35 @@ class DeviceListUpdater:
device_ids.append(verify_key.version)
return device_ids
+
+ async def handle_room_un_partial_stated(self, room_id: str) -> None:
+ """Handles sending appropriate device list updates in a room that has
+ gone from partial to full state.
+ """
+
+ pending_updates = (
+ await self.store.get_pending_remote_device_list_updates_for_room(room_id)
+ )
+
+ for user_id, device_id in pending_updates:
+ logger.info(
+ "Got pending device list update in room %s: %s / %s",
+ room_id,
+ user_id,
+ device_id,
+ )
+ position = await self.store.add_device_change_to_streams(
+ user_id,
+ [device_id],
+ room_ids=[room_id],
+ )
+
+ if not position:
+ # This should only happen if there are no updates, which
+ # shouldn't happen when we've passed in a non-empty set of
+ # device IDs.
+ continue
+
+ self.device_handler.notifier.on_new_event(
+ StreamKeyType.DEVICE_LIST, position, rooms=[room_id]
+ )