diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 961f8eb186..2567954679 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -273,11 +273,9 @@ class DeviceWorkerHandler:
possibly_left = possibly_changed | possibly_left
# Double check if we still share rooms with the given user.
- users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
- possibly_left
- )
+ users_rooms = await self.store.get_rooms_for_users(possibly_left)
for changed_user_id, entries in users_rooms.items():
- if any(e.room_id in room_ids for e in entries):
+ if any(rid in room_ids for rid in entries):
possibly_left.discard(changed_user_id)
else:
possibly_joined.discard(changed_user_id)
@@ -309,6 +307,17 @@ class DeviceWorkerHandler:
"self_signing_key": self_signing_key,
}
+ async def handle_room_un_partial_stated(self, room_id: str) -> None:
+ """Handles sending appropriate device list updates in a room that has
+ gone from partial to full state.
+ """
+
+ # TODO(faster_joins): worker mode support
+ # https://github.com/matrix-org/synapse/issues/12994
+ logger.error(
+ "Trying handling device list state for partial join: not supported on workers."
+ )
+
class DeviceHandler(DeviceWorkerHandler):
def __init__(self, hs: "HomeServer"):
@@ -746,6 +755,95 @@ class DeviceHandler(DeviceWorkerHandler):
finally:
self._handle_new_device_update_is_processing = False
+ async def handle_room_un_partial_stated(self, room_id: str) -> None:
+ """Handles sending appropriate device list updates in a room that has
+ gone from partial to full state.
+ """
+
+ # We defer to the device list updater to handle pending remote device
+ # list updates.
+ await self.device_list_updater.handle_room_un_partial_stated(room_id)
+
+ # Replay local updates.
+ (
+ join_event_id,
+ device_lists_stream_id,
+ ) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
+ room_id
+ )
+
+ # Get the local device list changes that have happened in the room since
+ # we started joining. If there are no updates there's nothing left to do.
+ changes = await self.store.get_device_list_changes_in_room(
+ room_id, device_lists_stream_id
+ )
+ local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
+ if not local_changes:
+ return
+
+ # Note: We have persisted the full state at this point, we just haven't
+ # cleared the `partial_room` flag.
+ join_state_ids = await self._state_storage.get_state_ids_for_event(
+ join_event_id, await_full_state=False
+ )
+ current_state_ids = await self.store.get_partial_current_state_ids(room_id)
+
+ # Now we need to work out all servers that might have been in the room
+ # at any point during our join.
+
+ # First we look for any membership states that have changed between the
+ # initial join and now...
+ all_keys = set(join_state_ids)
+ all_keys.update(current_state_ids)
+
+ potentially_changed_hosts = set()
+ for etype, state_key in all_keys:
+ if etype != EventTypes.Member:
+ continue
+
+ prev = join_state_ids.get((etype, state_key))
+ current = current_state_ids.get((etype, state_key))
+
+ if prev != current:
+ potentially_changed_hosts.add(get_domain_from_id(state_key))
+
+ # ... then we add all the hosts that are currently joined to the room...
+ current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
+ potentially_changed_hosts.update(current_hosts_in_room)
+
+ # ... and finally we remove any hosts that we were told about, as we
+ # will have sent device list updates to those hosts when they happened.
+ known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
+ room_id
+ )
+ potentially_changed_hosts.difference_update(known_hosts_at_join)
+
+ potentially_changed_hosts.discard(self.server_name)
+
+ if not potentially_changed_hosts:
+ # Nothing to do.
+ return
+
+ logger.info(
+ "Found %d changed hosts to send device list updates to",
+ len(potentially_changed_hosts),
+ )
+
+ for user_id, device_id in local_changes:
+ await self.store.add_device_list_outbound_pokes(
+ user_id=user_id,
+ device_id=device_id,
+ room_id=room_id,
+ stream_id=None,
+ hosts=potentially_changed_hosts,
+ context=None,
+ )
+
+ # Notify things that device lists need to be sent out.
+ self.notifier.notify_replication()
+ for host in potentially_changed_hosts:
+ self.federation_sender.send_device_messages(host, immediate=False)
+
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
@@ -836,6 +934,19 @@ class DeviceListUpdater:
)
return
+ # Check if we are partially joining any rooms. If so we need to store
+ # all device list updates so that we can handle them correctly once we
+ # know who is in the room.
+ # TODO(faster joins): this fetches and processes a bunch of data that we don't
+ # use. Could be replaced by a tighter query e.g.
+ # SELECT EXISTS(SELECT 1 FROM partial_state_rooms)
+ partial_rooms = await self.store.get_partial_state_room_resync_info()
+ if partial_rooms:
+ await self.store.add_remote_device_list_to_pending(
+ user_id,
+ device_id,
+ )
+
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
@@ -1175,3 +1286,35 @@ class DeviceListUpdater:
device_ids.append(verify_key.version)
return device_ids
+
+ async def handle_room_un_partial_stated(self, room_id: str) -> None:
+ """Handles sending appropriate device list updates in a room that has
+ gone from partial to full state.
+ """
+
+ pending_updates = (
+ await self.store.get_pending_remote_device_list_updates_for_room(room_id)
+ )
+
+ for user_id, device_id in pending_updates:
+ logger.info(
+ "Got pending device list update in room %s: %s / %s",
+ room_id,
+ user_id,
+ device_id,
+ )
+ position = await self.store.add_device_change_to_streams(
+ user_id,
+ [device_id],
+ room_ids=[room_id],
+ )
+
+ if not position:
+ # This should only happen if there are no updates, which
+ # shouldn't happen when we've passed in a non-empty set of
+ # device IDs.
+ continue
+
+ self.device_handler.notifier.on_new_event(
+ StreamKeyType.DEVICE_LIST, position, rooms=[room_id]
+ )
|