diff --git a/changelog.d/13874.misc b/changelog.d/13874.misc
new file mode 100644
index 0000000000..499e488c35
--- /dev/null
+++ b/changelog.d/13874.misc
@@ -0,0 +1 @@
+Faster room joins: Send device list updates to most servers in rooms with partial state.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 901e2310b7..6566b3bf3d 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -688,11 +688,15 @@ class DeviceHandler(DeviceWorkerHandler):
# Ignore any users that aren't ours
if self.hs.is_mine_id(user_id):
hosts = set(
- await self._storage_controllers.state.get_current_hosts_in_room(
+ await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
room_id
)
)
hosts.discard(self.server_name)
+ # For rooms with partial state, `hosts` is merely an
+ # approximation. When we transition to a full state room, we
+ # will have to send out device list updates to any servers we
+ # missed.
# Check if we've already sent this update to some hosts
if current_stream_id == stream_id:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index bbe568bf05..b1aa17047c 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,6 +23,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Tuple,
)
@@ -524,12 +525,53 @@ class StateStorageController:
return state_map.get(key)
async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
- """Get current hosts in room based on current state."""
+ """Get current hosts in room based on current state.
+
+        Blocks until we have full state for the given room. Blocking only happens for
+        rooms with partial state.
+
+ Returns:
+            A list of hosts in the room, sorted by longest in the room first (i.e.
+            sorted by the join with the lowest depth first).
+ """
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
+ async def get_current_hosts_in_room_or_partial_state_approximation(
+ self, room_id: str
+ ) -> Sequence[str]:
+ """Get approximation of current hosts in room based on current state.
+
+ For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
+ with the same order of results.
+
+ For rooms with partial state, no blocking occurs. Instead, the list of hosts
+ in the room at the time of joining is combined with the list of hosts which
+ joined the room afterwards. The returned list may include hosts that are not
+ actually in the room and exclude hosts that are in the room, since we may
+ calculate state incorrectly during the partial state phase. The order of results
+ is arbitrary for rooms with partial state.
+ """
+        # This list must be read first to mitigate races with un-partial stating (the
+        # room gaining full state). It will be empty for rooms with full state.
+ hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
+ room_id
+ )
+
+ hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
+ hosts_from_state_set = set(hosts_from_state)
+
+ # First take the list of hosts based on the current state.
+ # For rooms with partial state, this will be missing most hosts.
+ hosts = list(hosts_from_state)
+ # Then add in the list of hosts in the room at the time we joined.
+ # This will be an empty list for rooms with full state.
+ hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set)
+
+ return hosts
+
async def get_users_in_room_with_profiles(
self, room_id: str
) -> Dict[str, ProfileInfo]:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index bef66f1992..5dd116d766 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -25,6 +25,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Tuple,
Union,
cast,
@@ -1133,6 +1134,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
+ async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
+ """Gets the list of servers in a partial state room at the time we joined it.
+
+ Returns:
+ The `servers_in_room` list from the `/send_join` response for partial state
+ rooms. May not be accurate or complete, as it comes from a remote
+ homeserver.
+ An empty list for full state rooms.
+ """
+ return await self.db_pool.simple_select_onecol(
+ "partial_state_rooms_servers",
+ keyvalues={"room_id": room_id},
+ retcol="server_name",
+ desc="get_partial_state_servers_at_join",
+ )
+
async def get_partial_state_rooms_and_servers(
self,
) -> Mapping[str, Collection[str]]:
|