diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index bb60130afe..2b31ce54bb 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,7 +23,7 @@ from typing import (
List,
Mapping,
Optional,
- Sequence,
+ Set,
Tuple,
)
@@ -529,7 +529,18 @@ class StateStorageController:
)
return state_map.get(key)
- async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ """Get current hosts in room based on current state.
+
+ Blocks until we have full state for the given room. This only happens for rooms
+ with partial state.
+ """
+
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_current_hosts_in_room(room_id)
+
+ async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
@@ -542,11 +553,11 @@ class StateStorageController:
await self._partial_state_room_tracker.await_full_state(room_id)
- return await self.stores.main.get_current_hosts_in_room(room_id)
+ return await self.stores.main.get_current_hosts_in_room_ordered(room_id)
async def get_current_hosts_in_room_or_partial_state_approximation(
self, room_id: str
- ) -> Sequence[str]:
+ ) -> Collection[str]:
"""Get approximation of current hosts in room based on current state.
For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
@@ -566,14 +577,9 @@ class StateStorageController:
)
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
- hosts_from_state_set = set(hosts_from_state)
-
- # First take the list of hosts based on the current state.
- # For rooms with partial state, this will be missing most hosts.
- hosts = list(hosts_from_state)
- # Then add in the list of hosts in the room at the time we joined.
- # This will be an empty list for rooms with full state.
- hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set)
+
+ hosts = set(hosts_at_join)
+ hosts.update(hosts_from_state)
return hosts
|