diff --git a/changelog.d/13934.misc b/changelog.d/13934.misc
new file mode 100644
index 0000000000..6610a9f567
--- /dev/null
+++ b/changelog.d/13934.misc
@@ -0,0 +1 @@
+Correctly handle sending local device list updates to remote servers during a partial join.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f2ef591103..03082fce42 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -762,10 +762,90 @@ class DeviceHandler(DeviceWorkerHandler):
gone from partial to full state.
"""
- # We defer to the device list updater implementation as we're on the
- # right worker.
+ # We defer to the device list updater to handle pending remote device
+ # list updates.
await self.device_list_updater.handle_room_un_partial_stated(room_id)
+ # Replay local updates.
+ (
+ join_event_id,
+ device_lists_stream_id,
+ ) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
+ room_id
+ )
+
+ # Get the local device list changes that have happened in the room since
+ # we started joining. If there are no updates there's nothing left to do.
+ changes = await self.store.get_device_list_changes_in_room(
+ room_id, device_lists_stream_id
+ )
+ local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
+ if not local_changes:
+ return
+
+ # Note: We have persisted the full state at this point, we just haven't
+ # cleared the `partial_room` flag.
+ join_state_ids = await self._state_storage.get_state_ids_for_event(
+ join_event_id, await_full_state=False
+ )
+ current_state_ids = await self.store.get_partial_current_state_ids(room_id)
+
+ # Now we need to work out all servers that might have been in the room
+ # at any point during our join.
+
+ # First we look for any membership states that have changed between the
+ # initial join and now...
+ all_keys = set(join_state_ids)
+ all_keys.update(current_state_ids)
+
+ potentially_changed_hosts = set()
+ for etype, state_key in all_keys:
+ if etype != EventTypes.Member:
+ continue
+
+ prev = join_state_ids.get((etype, state_key))
+ current = current_state_ids.get((etype, state_key))
+
+ if prev != current:
+ potentially_changed_hosts.add(get_domain_from_id(state_key))
+
+ # ... then we add all the hosts that are currently joined to the room...
+ current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
+ potentially_changed_hosts.update(current_hosts_in_room)
+
+ # ... and finally we remove any hosts that we were told about, as we
+ # will have sent device list updates to those hosts when they happened.
+ known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
+ room_id
+ )
+ potentially_changed_hosts.difference_update(known_hosts_at_join)
+
+ potentially_changed_hosts.discard(self.server_name)
+
+ if not potentially_changed_hosts:
+ # Nothing to do.
+ return
+
+ logger.info(
+ "Found %d changed hosts to send device list updates to",
+ len(potentially_changed_hosts),
+ )
+
+ for user_id, device_id in local_changes:
+ await self.store.add_device_list_outbound_pokes(
+ user_id=user_id,
+ device_id=device_id,
+ room_id=room_id,
+ stream_id=None,
+ hosts=potentially_changed_hosts,
+ context=None,
+ )
+
+ # Notify things that device lists need to be sent out.
+ self.notifier.notify_replication()
+ for host in potentially_changed_hosts:
+ self.federation_sender.send_device_messages(host, immediate=False)
+
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1e562d4a40..18358eca46 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return changes
+ async def get_device_list_changes_in_room(
+ self, room_id: str, min_stream_id: int
+ ) -> Collection[Tuple[str, str]]:
+ """Get all device list changes that happened in the room since the given
+ stream ID.
+
+ Returns:
+ Collection of user ID/device ID tuples of all devices that have
+ changed
+ """
+
+ sql = """
+ SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
+ WHERE room_id = ? AND stream_id > ?
+ """
+
+ def get_device_list_changes_in_room_txn(
+ txn: LoggingTransaction,
+ ) -> Collection[Tuple[str, str]]:
+ txn.execute(sql, (room_id, min_stream_id))
+ return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+ return await self.db_pool.runInteraction(
+ "get_device_list_changes_in_room",
+ get_device_list_changes_in_room_txn,
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
room_id: str,
- stream_id: int,
+ stream_id: Optional[int],
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
- Marks the associated row in `device_lists_changes_in_room` as handled.
+ Marks the associated row in `device_lists_changes_in_room` as handled,
+ if `stream_id` is provided.
"""
def add_device_list_outbound_pokes_txn(
@@ -1969,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context=context,
)
- self.db_pool.simple_update_txn(
- txn,
- table="device_lists_changes_in_room",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "stream_id": stream_id,
- "room_id": room_id,
- },
- updatevalues={"converted_to_destinations": True},
- )
+ if stream_id:
+ self.db_pool.simple_update_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "room_id": room_id,
+ },
+ updatevalues={"converted_to_destinations": True},
+ )
if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 672c9a03fc..059eef5c22 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1256,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return entry is not None
+ async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
+ self, room_id: str
+ ) -> Tuple[str, int]:
+ """Get the event ID of the initial join that started the partial
+ join, and the device list stream ID at the point we started the partial
+ join.
+ """
+
+ result = await self.db_pool.simple_select_one(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("join_event_id", "device_lists_stream_id"),
+ desc="get_join_event_id_for_partial_state",
+ )
+ return result["join_event_id"], result["device_lists_stream_id"]
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|