diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 62bc4600fb..788f7d1e32 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -308,8 +308,24 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=10000)
- async def get_room_type(self, room_id: str) -> Optional[str]:
- raise NotImplementedError()
+ async def get_room_type(self, room_id: str) -> Union[Optional[str], Sentinel]:
+ """Fetch room type for given room.
+
+ Since this function is cached, any missing values would be cached as
+ `None`. In order to distinguish between an unencrypted room that has
+ `None` encryption and a room that is unknown to the server where we
+ might want to omit the value (which would make it cached as `None`),
+ instead we use the sentinel value `ROOM_UNKNOWN_SENTINEL`.
+ """
+
+ try:
+ create_event = await self.get_create_event_for_room(room_id)
+ return create_event.content.get(EventContentFields.ROOM_TYPE)
+ except NotFoundError:
+ # We use the sentinel value to distinguish between `None` which is a
+ # valid room type and a room that is unknown to the server so the value
+ # is just unset.
+ return ROOM_UNKNOWN_SENTINEL
@cachedList(cached_method_name="get_room_type", list_name="room_ids")
async def bulk_get_room_type(
@@ -535,7 +551,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="check_if_events_in_current_state",
)
- return frozenset(event_id for event_id, in rows)
+ return frozenset(event_id for (event_id,) in rows)
# FIXME: how should this be cached?
@cancellable
@@ -556,10 +572,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Map from type/state_key to event ID.
"""
+ if state_filter is None:
+ state_filter = StateFilter.all()
- where_clause, where_args = (
- state_filter or StateFilter.all()
- ).make_sql_filter_clause()
+ where_clause, where_args = (state_filter).make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version
@@ -568,7 +584,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
- results = StateMapWrapper(state_filter=state_filter or StateFilter.all())
+ results = StateMapWrapper(state_filter=state_filter)
sql = """
SELECT type, state_key, event_id FROM current_state_events
@@ -665,7 +681,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
context: EventContext,
) -> None:
"""Update the state group for a partial state event"""
- async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id:
+ async with (
+ self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id
+ ):
await self.db_pool.runInteraction(
"update_state_for_partial_state_event",
self._update_state_for_partial_state_event_txn,
@@ -736,6 +754,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
+ MEMBERS_CURRENT_STATE_UPDATE_NAME = "current_state_events_members_room_index"
def __init__(
self,
@@ -764,6 +783,13 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
self.DELETE_CURRENT_STATE_UPDATE_NAME,
self._background_remove_left_rooms,
)
+ self.db_pool.updates.register_background_index_update(
+ self.MEMBERS_CURRENT_STATE_UPDATE_NAME,
+ index_name="current_state_events_members_room_index",
+ table="current_state_events",
+ columns=["room_id", "membership"],
+ where_clause="type='m.room.member'",
+ )
async def _background_remove_left_rooms(
self, progress: JsonDict, batch_size: int
|