diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index e41c99027a..7d97f8f60e 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -97,6 +97,12 @@ class RoomSortOrder(Enum):
STATE_EVENTS = "state_events"
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PartialStateResyncInfo:
+ joined_via: Optional[str]
+ servers_in_room: List[str] = attr.ib(factory=list)
+
+
class RoomWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -1160,17 +1166,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
desc="get_partial_state_servers_at_join",
)
- async def get_partial_state_rooms_and_servers(
+ async def get_partial_state_room_resync_info(
self,
- ) -> Mapping[str, Collection[str]]:
- """Get all rooms containing events with partial state, and the servers known
- to be in the room.
+ ) -> Mapping[str, PartialStateResyncInfo]:
+ """Get all rooms containing events with partial state, and the information
+ needed to restart a "resync" of those rooms.
Returns:
A dictionary of rooms with partial state, with room IDs as keys and
lists of servers in rooms as values.
"""
- room_servers: Dict[str, List[str]] = {}
+ room_servers: Dict[str, PartialStateResyncInfo] = {}
+
+ rows = await self.db_pool.simple_select_list(
+ table="partial_state_rooms",
+ keyvalues={},
+ retcols=("room_id", "joined_via"),
+ desc="get_server_which_served_partial_join",
+ )
+
+ for row in rows:
+ room_id = row["room_id"]
+ joined_via = row["joined_via"]
+ room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
rows = await self.db_pool.simple_select_list(
"partial_state_rooms_servers",
@@ -1182,7 +1200,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
for row in rows:
room_id = row["room_id"]
server_name = row["server_name"]
- room_servers.setdefault(room_id, []).append(server_name)
+ entry = room_servers.get(room_id)
+ if entry is None:
+ # There is a foreign key constraint which enforces that every room_id in
+ # partial_state_rooms_servers appears in partial_state_rooms. So we
+ # expect `entry` to be non-null. (This reasoning fails if we've
+ # partial-joined between the two SELECTs, but this is unlikely to happen
+ # in practice.)
+ continue
+ entry.servers_in_room.append(server_name)
return room_servers
@@ -1827,6 +1853,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id: str,
servers: Collection[str],
device_lists_stream_id: int,
+ joined_via: str,
) -> None:
"""Mark the given room as containing events with partial state.
@@ -1842,6 +1869,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
servers: other servers known to be in the room
device_lists_stream_id: the device_lists stream ID at the time when we first
joined the room.
+ joined_via: the server name we requested a partial join from.
"""
await self.db_pool.runInteraction(
"store_partial_state_room",
@@ -1849,6 +1877,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id,
servers,
device_lists_stream_id,
+ joined_via,
)
def _store_partial_state_room_txn(
@@ -1857,6 +1886,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id: str,
servers: Collection[str],
device_lists_stream_id: int,
+ joined_via: str,
) -> None:
DatabasePool.simple_insert_txn(
txn,
@@ -1866,6 +1896,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
"device_lists_stream_id": device_lists_stream_id,
# To be updated later once the join event is persisted.
"join_event_id": None,
+ "joined_via": joined_via,
},
)
DatabasePool.simple_insert_many_txn(
|