summary refs log tree commit diff
path: root/synapse/storage/databases/main/state.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/databases/main/state.py42
1 files changed, 34 insertions, 8 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py

index 62bc4600fb..788f7d1e32 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -308,8 +308,24 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=10000) - async def get_room_type(self, room_id: str) -> Optional[str]: - raise NotImplementedError() + async def get_room_type(self, room_id: str) -> Union[Optional[str], Sentinel]: + """Fetch room type for given room. + + Since this function is cached, any missing values would be cached as + `None`. In order to distinguish between an unencrypted room that has + `None` encryption and a room that is unknown to the server where we + might want to omit the value (which would make it cached as `None`), + instead we use the sentinel value `ROOM_UNKNOWN_SENTINEL`. + """ + + try: + create_event = await self.get_create_event_for_room(room_id) + return create_event.content.get(EventContentFields.ROOM_TYPE) + except NotFoundError: + # We use the sentinel value to distinguish between `None` which is a + # valid room type and a room that is unknown to the server so the value + # is just unset. + return ROOM_UNKNOWN_SENTINEL @cachedList(cached_method_name="get_room_type", list_name="room_ids") async def bulk_get_room_type( @@ -535,7 +551,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="check_if_events_in_current_state", ) - return frozenset(event_id for event_id, in rows) + return frozenset(event_id for (event_id,) in rows) # FIXME: how should this be cached? @cancellable @@ -556,10 +572,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: Map from type/state_key to event ID. """ + if state_filter is None: + state_filter = StateFilter.all() - where_clause, where_args = ( - state_filter or StateFilter.all() - ).make_sql_filter_clause() + where_clause, where_args = (state_filter).make_sql_filter_clause() if not where_clause: # We delegate to the cached version @@ -568,7 +584,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, ) -> StateMap[str]: - results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) + results = StateMapWrapper(state_filter=state_filter) sql = """ SELECT type, state_key, event_id FROM current_state_events @@ -665,7 +681,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): context: EventContext, ) -> None: """Update the state group for a partial state event""" - async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id: + async with ( + self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id + ): await self.db_pool.runInteraction( "update_state_for_partial_state_event", self._update_state_for_partial_state_event_txn, @@ -736,6 +754,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" + MEMBERS_CURRENT_STATE_UPDATE_NAME = "current_state_events_members_room_index" def __init__( self, @@ -764,6 +783,13 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, ) + self.db_pool.updates.register_background_index_update( + self.MEMBERS_CURRENT_STATE_UPDATE_NAME, + index_name="current_state_events_members_room_index", + table="current_state_events", + columns=["room_id", "membership"], + where_clause="type='m.room.member'", + ) async def _background_remove_left_rooms( self, progress: JsonDict, batch_size: int