diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index b2a67aff89..5188b2f7a4 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -41,7 +41,7 @@ from typing import (
import attr
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
@@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_event(create_id)
return create_event
+ @cached(max_entries=10000)
+ async def get_room_type(self, room_id: str) -> Optional[str]:
+ """Get the room type for a given room. The server must be joined to the
+ given room.
+ """
+
+ row = await self.db_pool.simple_select_one(
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ retcols=("room_type",),
+ allow_none=True,
+ desc="get_room_type",
+ )
+
+ if row is not None:
+ return row[0]
+
+ # If we haven't updated `room_stats_state` with the room yet, query the
+ # create event directly.
+ create_event = await self.get_create_event_for_room(room_id)
+ room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+ return room_type
+
+ @cachedList(cached_method_name="get_room_type", list_name="room_ids")
+ async def bulk_get_room_type(
+ self, room_ids: Set[str]
+ ) -> Mapping[str, Optional[str]]:
+ """Bulk fetch room types for the given rooms, the server must be in all
+ the rooms given.
+ """
+
+ rows = await self.db_pool.simple_select_many_batch(
+ table="room_stats_state",
+ column="room_id",
+ iterable=room_ids,
+ retcols=("room_id", "room_type"),
+ desc="bulk_get_room_type",
+ )
+
+ # If we haven't updated `room_stats_state` with the room yet, query the
+ # create events directly. This should happen only rarely so we don't
+ # mind if we do this in a loop.
+ results = dict(rows)
+ for room_id in room_ids - results.keys():
+ create_event = await self.get_create_event_for_room(room_id)
+ room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+ results[room_id] = room_type
+
+ return results
+
@cached(max_entries=100000, iterable=True)
async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
|