diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1151fb0cc3..1e562d4a40 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1995,3 +1995,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
add_device_list_outbound_pokes_txn,
stream_ids,
)
+
+ async def add_remote_device_list_to_pending(
+ self, user_id: str, device_id: str
+ ) -> None:
+ """Add a device list update to the table tracking remote device list
+ updates during partial joins.
+ """
+
+ async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
+ await self.db_pool.simple_upsert(
+ table="device_lists_remote_pending",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={"stream_id": stream_id},
+ desc="add_remote_device_list_to_pending",
+ )
+
+ async def get_pending_remote_device_list_updates_for_room(
+ self, room_id: str
+ ) -> Collection[Tuple[str, str]]:
+ """Get the set of remote device list updates from the pending table for
+ the room.
+ """
+
+ min_device_stream_id = await self.db_pool.simple_select_one_onecol(
+ table="partial_state_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="device_lists_stream_id",
+ desc="get_pending_remote_device_list_updates_for_room_device",
+ )
+
+ sql = """
+ SELECT user_id, device_id FROM device_lists_remote_pending AS d
+ INNER JOIN current_state_events AS c ON
+ type = 'm.room.member'
+ AND state_key = user_id
+ AND membership = 'join'
+ WHERE
+ room_id = ? AND stream_id > ?
+ """
+
+ def get_pending_remote_device_list_updates_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Collection[Tuple[str, str]]:
+ txn.execute(sql, (room_id, min_device_stream_id))
+ return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+ return await self.db_pool.runInteraction(
+ "get_pending_remote_device_list_updates_for_room",
+ get_pending_remote_device_list_updates_for_room_txn,
+ )
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 064c332fb7..672c9a03fc 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1217,6 +1217,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
+ # We now delete anything from `device_lists_remote_pending` with a
+ # stream ID less than the minimum
+ # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
+ device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
+ txn,
+ table="partial_state_rooms",
+ keyvalues={},
+ retcol="MIN(device_lists_stream_id)",
+ allow_none=True,
+ )
+ if device_lists_stream_id is None:
+ # There are no rooms being currently partially joined, so we delete everything.
+ txn.execute("DELETE FROM device_lists_remote_pending")
+ else:
+ sql = """
+ DELETE FROM device_lists_remote_pending
+ WHERE stream_id <= ?
+ """
+ txn.execute(sql, (device_lists_stream_id,))
+
@cached()
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
|