diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 68b0806041..e0b7b7e194 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1382,6 +1382,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
partial_state_rooms = {row[0] for row in rows}
return {room_id: room_id in partial_state_rooms for room_id in room_ids}
+ @cached(max_entries=10000, iterable=True)
+ async def get_partial_rooms(self) -> AbstractSet[str]:
+ """Get any "partial-state" rooms which the user is in.
+
+ This is fast as the set of partially stated rooms at any point across
+ the whole server is small, and so such a query is fast. This is also
+ faster than looking up whether a set of room ID's are partially stated
+ via `is_partial_state_room_batched(...)` because of the sheer amount of
+ CPU time looking all the rooms up in the cache.
+ """
+
+ def _get_partial_rooms_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> AbstractSet[str]:
+ sql = """
+ SELECT room_id FROM partial_state_rooms
+ """
+ txn.execute(sql)
+ return {room_id for (room_id,) in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_partial_rooms_for_user", _get_partial_rooms_for_user_txn
+ )
+
async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
self, room_id: str
) -> Tuple[str, int]:
@@ -2341,6 +2365,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._invalidate_cache_and_stream(
txn, self._get_partial_state_servers_at_join, (room_id,)
)
+ self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms)
async def write_partial_state_rooms_join_event_id(
self,
@@ -2562,6 +2587,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self._invalidate_cache_and_stream(
txn, self._get_partial_state_servers_at_join, (room_id,)
)
+ self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms)
DatabasePool.simple_insert_txn(
txn,
|