diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 2045169b9a..52efd4a171 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -569,10 +569,11 @@ class StateStorageController:
is arbitrary for rooms with partial state.
"""
# We have to read this list first to mitigate races with un-partial stating.
- # This will be empty for rooms with full state.
hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
room_id
)
+ if hosts_at_join is None:
+ hosts_at_join = frozenset()
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4ddb27f686..644bbb8878 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -18,6 +18,7 @@ from abc import abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Any,
Awaitable,
Collection,
@@ -25,7 +26,6 @@ from typing import (
List,
Mapping,
Optional,
- Sequence,
Set,
Tuple,
Union,
@@ -109,7 +109,7 @@ class RoomSortOrder(Enum):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PartialStateResyncInfo:
joined_via: Optional[str]
- servers_in_room: List[str] = attr.ib(factory=list)
+ servers_in_room: Set[str] = attr.ib(factory=set)
class RoomWorkerStore(CacheInvalidationWorkerStore):
@@ -1193,21 +1193,35 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
- @cached(iterable=True)
- async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
- """Gets the list of servers in a partial state room at the time we joined it.
+ async def get_partial_state_servers_at_join(
+ self, room_id: str
+ ) -> Optional[AbstractSet[str]]:
+ """Gets the set of servers in a partial state room at the time we joined it.
Returns:
The `servers_in_room` list from the `/send_join` response for partial state
rooms. May not be accurate or complete, as it comes from a remote
homeserver.
- An empty list for full state rooms.
+ `None` for full state rooms.
"""
- return await self.db_pool.simple_select_onecol(
- "partial_state_rooms_servers",
- keyvalues={"room_id": room_id},
- retcol="server_name",
- desc="get_partial_state_servers_at_join",
+ servers_in_room = await self._get_partial_state_servers_at_join(room_id)
+
+ if len(servers_in_room) == 0:
+ return None
+
+ return servers_in_room
+
+ @cached(iterable=True)
+ async def _get_partial_state_servers_at_join(
+ self, room_id: str
+ ) -> AbstractSet[str]:
+ return frozenset(
+ await self.db_pool.simple_select_onecol(
+ "partial_state_rooms_servers",
+ keyvalues={"room_id": room_id},
+ retcol="server_name",
+ desc="get_partial_state_servers_at_join",
+ )
)
async def get_partial_state_room_resync_info(
@@ -1252,7 +1266,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# partial-joined between the two SELECTs, but this is unlikely to happen
# in practice.)
continue
- entry.servers_in_room.append(server_name)
+ entry.servers_in_room.add(server_name)
return room_servers
@@ -1942,7 +1956,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
async def store_partial_state_room(
self,
room_id: str,
- servers: Collection[str],
+ servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@@ -1957,11 +1971,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
room_id: the ID of the room
- servers: other servers known to be in the room
+ servers: other servers known to be in the room. must include `joined_via`.
device_lists_stream_id: the device_lists stream ID at the time when we first
joined the room.
joined_via: the server name we requested a partial join from.
"""
+ assert joined_via in servers
+
await self.db_pool.runInteraction(
"store_partial_state_room",
self._store_partial_state_room_txn,
@@ -1975,7 +1991,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
txn: LoggingTransaction,
room_id: str,
- servers: Collection[str],
+ servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@@ -1998,7 +2014,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
+ txn, self._get_partial_state_servers_at_join, (room_id,)
)
async def write_partial_state_rooms_join_event_id(
@@ -2409,7 +2425,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
+ txn, self._get_partial_state_servers_at_join, (room_id,)
)
DatabasePool.simple_insert_txn(
|