From 606da398fc4c693f2e75b9520264e0fc51d03581 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 11 Jul 2024 16:00:44 +0100 Subject: Fix filtering room types on remote rooms (#17434) We can only fetch room types for rooms the server is in, so we need to only filter rooms that we're joined to. Also includes a perf fix to bulk fetch room types. --- synapse/handlers/sliding_sync.py | 22 +++++++------- synapse/storage/databases/main/state.py | 52 ++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 13 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 818b13621c..8e6c2fb860 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -24,13 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple import attr from immutabledict import immutabledict -from synapse.api.constants import ( - AccountDataTypes, - Direction, - EventContentFields, - EventTypes, - Membership, -) +from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership from synapse.events import EventBase from synapse.events.utils import strip_event from synapse.handlers.relations import BundledAggregations @@ -959,11 +953,15 @@ class SlidingSyncHandler: # provided in the list. `None` is a valid type for rooms which do not have a # room type. if filters.room_types is not None or filters.not_room_types is not None: - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for room_id in filtered_room_id_set.copy(): - create_event = await self.store.get_create_event_for_room(room_id) - room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + room_to_type = await self.store.bulk_get_room_type( + { + room_id + for room_id in filtered_room_id_set + # We only know the room types for joined rooms + if sync_room_map[room_id].membership == Membership.JOIN + } + ) + for room_id, room_type in room_to_type.items(): if ( filters.room_types is not None and room_type not in filters.room_types diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index b2a67aff89..5188b2f7a4 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -41,7 +41,7 @@ from typing import ( import attr -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase @@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_event(create_id) return create_event + @cached(max_entries=10000) + async def get_room_type(self, room_id: str) -> Optional[str]: + """Get the room type for a given room. The server must be joined to the + given room. + """ + + row = await self.db_pool.simple_select_one( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcols=("room_type",), + allow_none=True, + desc="get_room_type", + ) + + if row is not None: + return row[0] + + # If we haven't updated `room_stats_state` with the room yet, query the + # create event directly. + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + return room_type + + @cachedList(cached_method_name="get_room_type", list_name="room_ids") + async def bulk_get_room_type( + self, room_ids: Set[str] + ) -> Mapping[str, Optional[str]]: + """Bulk fetch room types for the given rooms, the server must be in all + the rooms given. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="room_stats_state", + column="room_id", + iterable=room_ids, + retcols=("room_id", "room_type"), + desc="bulk_get_room_type", + ) + + # If we haven't updated `room_stats_state` with the room yet, query the + # create events directly. This should happen only rarely so we don't + # mind if we do this in a loop. + results = dict(rows) + for room_id in room_ids - results.keys(): + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + results[room_id] = room_type + + return results + @cached(max_entries=100000, iterable=True) async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the -- cgit 1.4.1