diff options
author | Eric Eastwood <eric.eastwood@beta.gouv.fr> | 2024-07-30 13:20:29 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-30 13:20:29 -0500 |
commit | 46de0ee16be8731f0ed68654edc75aced1510b19 (patch) | |
tree | b1538d66ed8ed2651c7db671fe76231542aa33d9 /synapse | |
parent | Sliding Sync: Add receipts extension (MSC3960) (#17489) (diff) | |
download | synapse-46de0ee16be8731f0ed68654edc75aced1510b19.tar.xz |
Sliding Sync: Update filters to be robust against remote invite rooms (#17450)
Update `filters.is_encrypted` and `filters.room_types`/`filters.not_room_types` to be robust when dealing with remote invite rooms in Sliding Sync. Part of [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync. Follow-up to https://github.com/element-hq/synapse/pull/17434. We now take the current state into account first, then fall back to stripped state for invite/knock rooms, and finally to historical state. If we can't determine the info needed to filter a room (from current state, stripped state, or historical state), the room is filtered out.
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/api/constants.py | 8 | ||||
-rw-r--r-- | synapse/events/__init__.py | 19 | ||||
-rw-r--r-- | synapse/events/utils.py | 29 | ||||
-rw-r--r-- | synapse/handlers/sliding_sync.py | 385 | ||||
-rw-r--r-- | synapse/handlers/stats.py | 4 | ||||
-rw-r--r-- | synapse/storage/_base.py | 4 | ||||
-rw-r--r-- | synapse/storage/databases/main/cache.py | 22 | ||||
-rw-r--r-- | synapse/storage/databases/main/state.py | 215 |
8 files changed, 611 insertions, 75 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 85001d9676..7dcb1e01fd 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -225,6 +225,11 @@ class EventContentFields: # This is deprecated in MSC2175. ROOM_CREATOR: Final = "creator" + # The version of the room for `m.room.create` events. + ROOM_VERSION: Final = "room_version" + + ROOM_NAME: Final = "name" + # Used in m.room.guest_access events. GUEST_ACCESS: Final = "guest_access" @@ -237,6 +242,9 @@ class EventContentFields: # an unspecced field added to to-device messages to identify them uniquely-ish TO_DEVICE_MSGID: Final = "org.matrix.msgid" + # `m.room.encryption`` algorithm field + ENCRYPTION_ALGORITHM: Final = "algorithm" + class EventUnsignedContentFields: """Fields found inside the 'unsigned' data on events""" diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 36e0f47e51..2e56b671f0 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -554,3 +554,22 @@ def relation_from_event(event: EventBase) -> Optional[_EventRelation]: aggregation_key = None return _EventRelation(parent_id, rel_type, aggregation_key) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StrippedStateEvent: + """ + A stripped down state event. Usually used for remote invite/knocks so the user can + make an informed decision on whether they want to join. + + Attributes: + type: Event `type` + state_key: Event `state_key` + sender: Event `sender` + content: Event `content` + """ + + type: str + state_key: str + sender: str + content: Dict[str, Any] diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f937fd4698..54f94add4d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -49,7 +49,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict, Requester -from . import EventBase, make_event_from_dict +from . 
import EventBase, StrippedStateEvent, make_event_from_dict if TYPE_CHECKING: from synapse.handlers.relations import BundledAggregations @@ -854,3 +854,30 @@ def strip_event(event: EventBase) -> JsonDict: "content": event.content, "sender": event.sender, } + + +def parse_stripped_state_event(raw_stripped_event: Any) -> Optional[StrippedStateEvent]: + """ + Given a raw value from an event's `unsigned` field, attempt to parse it into a + `StrippedStateEvent`. + """ + if isinstance(raw_stripped_event, dict): + # All of these fields are required + type = raw_stripped_event.get("type") + state_key = raw_stripped_event.get("state_key") + sender = raw_stripped_event.get("sender") + content = raw_stripped_event.get("content") + if ( + isinstance(type, str) + and isinstance(state_key, str) + and isinstance(sender, str) + and isinstance(content, dict) + ): + return StrippedStateEvent( + type=type, + state_key=state_key, + sender=sender, + content=content, + ) + + return None diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 7a734f6712..530e7b7b4e 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -17,6 +17,7 @@ # [This file includes modifications made by New Vector Limited] # # +import enum import logging from enum import Enum from itertools import chain @@ -26,23 +27,35 @@ from typing import ( Dict, Final, List, + Literal, Mapping, Optional, Sequence, Set, Tuple, + Union, ) import attr from immutabledict import immutabledict from typing_extensions import assert_never -from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership -from synapse.events import EventBase -from synapse.events.utils import strip_event +from synapse.api.constants import ( + AccountDataTypes, + Direction, + EventContentFields, + EventTypes, + Membership, +) +from synapse.events import EventBase, StrippedStateEvent +from synapse.events.utils import parse_stripped_state_event, strip_event from 
synapse.handlers.relations import BundledAggregations from synapse.logging.opentracing import log_kv, start_active_span, tag_args, trace from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary +from synapse.storage.databases.main.state import ( + ROOM_UNKNOWN_SENTINEL, + Sentinel as StateSentinel, +) from synapse.storage.databases.main.stream import CurrentStateDeltaMembership from synapse.storage.roommember import MemberSummary from synapse.types import ( @@ -50,6 +63,7 @@ from synapse.types import ( JsonDict, JsonMapping, MultiWriterStreamToken, + MutableStateMap, PersistedEventPosition, Requester, RoomStreamToken, @@ -71,6 +85,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup and subsequent type narrowing. + UNSET_SENTINEL = object() + + # The event types that clients should consider as new activity. DEFAULT_BUMP_EVENT_TYPES = { EventTypes.Create, @@ -1172,6 +1192,265 @@ class SlidingSyncHandler: # return None + async def _bulk_get_stripped_state_for_rooms_from_sync_room_map( + self, + room_ids: StrCollection, + sync_room_map: Dict[str, _RoomMembershipForUser], + ) -> Dict[str, Optional[StateMap[StrippedStateEvent]]]: + """ + Fetch stripped state for a list of room IDs. Stripped state is only + applicable to invite/knock rooms. Other rooms will have `None` as their + stripped state. + + For invite rooms, we pull from `unsigned.invite_room_state`. + For knock rooms, we pull from `unsigned.knock_room_state`. + + Args: + room_ids: Room IDs to fetch stripped state for + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. + + Returns: + Mapping from room_id to mapping of (type, state_key) to stripped state + event. 
+ """ + room_id_to_stripped_state_map: Dict[ + str, Optional[StateMap[StrippedStateEvent]] + ] = {} + + # Fetch what we haven't before + room_ids_to_fetch = [ + room_id + for room_id in room_ids + if room_id not in room_id_to_stripped_state_map + ] + + # Gather a list of event IDs we can grab stripped state from + invite_or_knock_event_ids: List[str] = [] + for room_id in room_ids_to_fetch: + if sync_room_map[room_id].membership in ( + Membership.INVITE, + Membership.KNOCK, + ): + event_id = sync_room_map[room_id].event_id + # If this is an invite/knock then there should be an event_id + assert event_id is not None + invite_or_knock_event_ids.append(event_id) + else: + room_id_to_stripped_state_map[room_id] = None + + invite_or_knock_events = await self.store.get_events(invite_or_knock_event_ids) + for invite_or_knock_event in invite_or_knock_events.values(): + room_id = invite_or_knock_event.room_id + membership = invite_or_knock_event.membership + + raw_stripped_state_events = None + if membership == Membership.INVITE: + invite_room_state = invite_or_knock_event.unsigned.get( + "invite_room_state" + ) + raw_stripped_state_events = invite_room_state + elif membership == Membership.KNOCK: + knock_room_state = invite_or_knock_event.unsigned.get( + "knock_room_state" + ) + raw_stripped_state_events = knock_room_state + else: + raise AssertionError( + f"Unexpected membership {membership} (this is a problem with Synapse itself)" + ) + + stripped_state_map: Optional[MutableStateMap[StrippedStateEvent]] = None + # Scrutinize unsigned things. 
`raw_stripped_state_events` should be a list + # of stripped events + if raw_stripped_state_events is not None: + stripped_state_map = {} + if isinstance(raw_stripped_state_events, list): + for raw_stripped_event in raw_stripped_state_events: + stripped_state_event = parse_stripped_state_event( + raw_stripped_event + ) + if stripped_state_event is not None: + stripped_state_map[ + ( + stripped_state_event.type, + stripped_state_event.state_key, + ) + ] = stripped_state_event + + room_id_to_stripped_state_map[room_id] = stripped_state_map + + return room_id_to_stripped_state_map + + async def _bulk_get_partial_current_state_content_for_rooms( + self, + content_type: Literal[ + # `content.type` from `EventTypes.Create`` + "room_type", + # `content.algorithm` from `EventTypes.RoomEncryption` + "room_encryption", + ], + room_ids: Set[str], + sync_room_map: Dict[str, _RoomMembershipForUser], + to_token: StreamToken, + room_id_to_stripped_state_map: Dict[ + str, Optional[StateMap[StrippedStateEvent]] + ], + ) -> Mapping[str, Union[Optional[str], StateSentinel]]: + """ + Get the given state event content for a list of rooms. First we check the + current state of the room, then fallback to stripped state if available, then + historical state. + + Args: + content_type: Which content to grab + room_ids: Room IDs to fetch the given content field for. + sync_room_map: Dictionary of room IDs to sort along with membership + information in the room at the time of `to_token`. + to_token: We filter based on the state of the room at this token + room_id_to_stripped_state_map: This does not need to be filled in before + calling this function. Mapping from room_id to mapping of (type, state_key) + to stripped state event. Modified in place when we fetch new rooms so we can + save work next time this function is called. + + Returns: + A mapping from room ID to the state event content if the room has + the given state event (event_type, ""), otherwise `None`. 
Rooms unknown to + this server will return `ROOM_UNKNOWN_SENTINEL`. + """ + room_id_to_content: Dict[str, Union[Optional[str], StateSentinel]] = {} + + # As a bulk shortcut, use the current state if the server is particpating in the + # room (meaning we have current state). Ideally, for leave/ban rooms, we would + # want the state at the time of the membership instead of current state to not + # leak anything but we consider the create/encryption stripped state events to + # not be a secret given they are often set at the start of the room and they are + # normally handed out on invite/knock. + # + # Be mindful to only use this for non-sensitive details. For example, even + # though the room name/avatar/topic are also stripped state, they seem a lot + # more senstive to leak the current state value of. + # + # Since this function is cached, we need to make a mutable copy via + # `dict(...)`. + event_type = "" + event_content_field = "" + if content_type == "room_type": + event_type = EventTypes.Create + event_content_field = EventContentFields.ROOM_TYPE + room_id_to_content = dict(await self.store.bulk_get_room_type(room_ids)) + elif content_type == "room_encryption": + event_type = EventTypes.RoomEncryption + event_content_field = EventContentFields.ENCRYPTION_ALGORITHM + room_id_to_content = dict( + await self.store.bulk_get_room_encryption(room_ids) + ) + else: + assert_never(content_type) + + room_ids_with_results = [ + room_id + for room_id, content_field in room_id_to_content.items() + if content_field is not ROOM_UNKNOWN_SENTINEL + ] + + # We might not have current room state for remote invite/knocks if we are + # the first person on our server to see the room. The best we can do is look + # in the optional stripped state from the invite/knock event. 
+ room_ids_without_results = room_ids.difference( + chain( + room_ids_with_results, + [ + room_id + for room_id, stripped_state_map in room_id_to_stripped_state_map.items() + if stripped_state_map is not None + ], + ) + ) + room_id_to_stripped_state_map.update( + await self._bulk_get_stripped_state_for_rooms_from_sync_room_map( + room_ids_without_results, sync_room_map + ) + ) + + # Update our `room_id_to_content` map based on the stripped state + # (applies to invite/knock rooms) + rooms_ids_without_stripped_state: Set[str] = set() + for room_id in room_ids_without_results: + stripped_state_map = room_id_to_stripped_state_map.get( + room_id, Sentinel.UNSET_SENTINEL + ) + assert stripped_state_map is not Sentinel.UNSET_SENTINEL, ( + f"Stripped state left unset for room {room_id}. " + + "Make sure you're calling `_bulk_get_stripped_state_for_rooms_from_sync_room_map(...)` " + + "with that room_id. (this is a problem with Synapse itself)" + ) + + # If there is some stripped state, we assume the remote server passed *all* + # of the potential stripped state events for the room. + if stripped_state_map is not None: + create_stripped_event = stripped_state_map.get((EventTypes.Create, "")) + stripped_event = stripped_state_map.get((event_type, "")) + # Sanity check that we at-least have the create event + if create_stripped_event is not None: + if stripped_event is not None: + room_id_to_content[room_id] = stripped_event.content.get( + event_content_field + ) + else: + # Didn't see the state event we're looking for in the stripped + # state so we can assume relevant content field is `None`. + room_id_to_content[room_id] = None + else: + rooms_ids_without_stripped_state.add(room_id) + + # Last resort, we might not have current room state for rooms that the + # server has left (no one local is in the room) but we can look at the + # historical state. + # + # Update our `room_id_to_content` map based on the state at the time of + # the membership event. 
+ for room_id in rooms_ids_without_stripped_state: + # TODO: It would be nice to look this up in a bulk way (N+1 queries) + # + # TODO: `get_state_at(...)` doesn't take into account the "current state". + room_state = await self.storage_controllers.state.get_state_at( + room_id=room_id, + stream_position=to_token.copy_and_replace( + StreamKeyType.ROOM, + sync_room_map[room_id].event_pos.to_room_stream_token(), + ), + state_filter=StateFilter.from_types( + [ + (EventTypes.Create, ""), + (event_type, ""), + ] + ), + # Partially-stated rooms should have all state events except for + # remote membership events so we don't need to wait at all because + # we only want the create event and some non-member event. + await_full_state=False, + ) + # We can use the create event as a canary to tell whether the server has + # seen the room before + create_event = room_state.get((EventTypes.Create, "")) + state_event = room_state.get((event_type, "")) + + if create_event is None: + # Skip for unknown rooms + continue + + if state_event is not None: + room_id_to_content[room_id] = state_event.content.get( + event_content_field + ) + else: + # Didn't see the state event we're looking for in the stripped + # state so we can assume relevant content field is `None`. + room_id_to_content[room_id] = None + + return room_id_to_content + @trace async def filter_rooms( self, @@ -1194,6 +1473,10 @@ class SlidingSyncHandler: A filtered dictionary of room IDs along with membership information in the room at the time of `to_token`. 
""" + room_id_to_stripped_state_map: Dict[ + str, Optional[StateMap[StrippedStateEvent]] + ] = {} + filtered_room_id_set = set(sync_room_map.keys()) # Filter for Direct-Message (DM) rooms @@ -1213,31 +1496,34 @@ class SlidingSyncHandler: if not sync_room_map[room_id].is_dm } - if filters.spaces: + if filters.spaces is not None: raise NotImplementedError() # Filter for encrypted rooms if filters.is_encrypted is not None: + room_id_to_encryption = ( + await self._bulk_get_partial_current_state_content_for_rooms( + content_type="room_encryption", + room_ids=filtered_room_id_set, + to_token=to_token, + sync_room_map=sync_room_map, + room_id_to_stripped_state_map=room_id_to_stripped_state_map, + ) + ) + # Make a copy so we don't run into an error: `Set changed size during # iteration`, when we filter out and remove items for room_id in filtered_room_id_set.copy(): - state_at_to_token = await self.storage_controllers.state.get_state_at( - room_id, - to_token, - state_filter=StateFilter.from_types( - [(EventTypes.RoomEncryption, "")] - ), - # Partially-stated rooms should have all state events except for the - # membership events so we don't need to wait because we only care - # about retrieving the `EventTypes.RoomEncryption` state event here. - # Plus we don't want to block the whole sync waiting for this one - # room. - await_full_state=False, - ) - is_encrypted = state_at_to_token.get((EventTypes.RoomEncryption, "")) + encryption = room_id_to_encryption.get(room_id, ROOM_UNKNOWN_SENTINEL) + + # Just remove rooms if we can't determine their encryption status + if encryption is ROOM_UNKNOWN_SENTINEL: + filtered_room_id_set.remove(room_id) + continue # If we're looking for encrypted rooms, filter out rooms that are not # encrypted and vice versa + is_encrypted = encryption is not None if (filters.is_encrypted and not is_encrypted) or ( not filters.is_encrypted and is_encrypted ): @@ -1263,15 +1549,26 @@ class SlidingSyncHandler: # provided in the list. 
`None` is a valid type for rooms which do not have a # room type. if filters.room_types is not None or filters.not_room_types is not None: - room_to_type = await self.store.bulk_get_room_type( - { - room_id - for room_id in filtered_room_id_set - # We only know the room types for joined rooms - if sync_room_map[room_id].membership == Membership.JOIN - } + room_id_to_type = ( + await self._bulk_get_partial_current_state_content_for_rooms( + content_type="room_type", + room_ids=filtered_room_id_set, + to_token=to_token, + sync_room_map=sync_room_map, + room_id_to_stripped_state_map=room_id_to_stripped_state_map, + ) ) - for room_id, room_type in room_to_type.items(): + + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for room_id in filtered_room_id_set.copy(): + room_type = room_id_to_type.get(room_id, ROOM_UNKNOWN_SENTINEL) + + # Just remove rooms if we can't determine their type + if room_type is ROOM_UNKNOWN_SENTINEL: + filtered_room_id_set.remove(room_id) + continue + if ( filters.room_types is not None and room_type not in filters.room_types @@ -1284,13 +1581,24 @@ class SlidingSyncHandler: ): filtered_room_id_set.remove(room_id) - if filters.room_name_like: + if filters.room_name_like is not None: + # TODO: The room name is a bit more sensitive to leak than the + # create/encryption event. Maybe we should consider a better way to fetch + # historical state before implementing this. 
+ # + # room_id_to_create_content = await self._bulk_get_partial_current_state_content_for_rooms( + # content_type="room_name", + # room_ids=filtered_room_id_set, + # to_token=to_token, + # sync_room_map=sync_room_map, + # room_id_to_stripped_state_map=room_id_to_stripped_state_map, + # ) raise NotImplementedError() - if filters.tags: + if filters.tags is not None: raise NotImplementedError() - if filters.not_tags: + if filters.not_tags is not None: raise NotImplementedError() # Assemble a new sync room map but only with the `filtered_room_id_set` @@ -1371,14 +1679,17 @@ class SlidingSyncHandler: in the room at the time of `to_token`. to_token: The point in the stream to sync up to. """ - room_state_ids: StateMap[str] + state_ids: StateMap[str] # People shouldn't see past their leave/ban event if room_membership_for_user_at_to_token.membership in ( Membership.LEAVE, Membership.BAN, ): - # TODO: `get_state_ids_at(...)` doesn't take into account the "current state" - room_state_ids = await self.storage_controllers.state.get_state_ids_at( + # TODO: `get_state_ids_at(...)` doesn't take into account the "current + # state". Maybe we need to use + # `get_forward_extremities_for_room_at_stream_ordering(...)` to "Fetch the + # current state at the time." 
+ state_ids = await self.storage_controllers.state.get_state_ids_at( room_id, stream_position=to_token.copy_and_replace( StreamKeyType.ROOM, @@ -1397,7 +1708,7 @@ class SlidingSyncHandler: ) # Otherwise, we can get the latest current state in the room else: - room_state_ids = await self.storage_controllers.state.get_current_state_ids( + state_ids = await self.storage_controllers.state.get_current_state_ids( room_id, state_filter, # Partially-stated rooms should have all state events except for @@ -1412,7 +1723,7 @@ class SlidingSyncHandler: ) # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` - return room_state_ids + return state_ids async def get_current_state_at( self, @@ -1432,17 +1743,17 @@ class SlidingSyncHandler: in the room at the time of `to_token`. to_token: The point in the stream to sync up to. """ - room_state_ids = await self.get_current_state_ids_at( + state_ids = await self.get_current_state_ids_at( room_id=room_id, room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, state_filter=state_filter, to_token=to_token, ) - event_map = await self.store.get_events(list(room_state_ids.values())) + event_map = await self.store.get_events(list(state_ids.values())) state_map = {} - for key, event_id in room_state_ids.items(): + for key, event_id in state_ids.items(): event = event_map.get(event_id) if event: state_map[key] = event diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 1c94f3ca46..8f90c17060 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -293,7 +293,9 @@ class StatsHandler: "history_visibility" ) elif delta.event_type == EventTypes.RoomEncryption: - room_state["encryption"] = event_content.get("algorithm") + room_state["encryption"] = event_content.get( + EventContentFields.ENCRYPTION_ALGORITHM + ) elif delta.event_type == EventTypes.Name: room_state["name"] = event_content.get("name") elif delta.event_type == EventTypes.Topic: diff --git 
a/synapse/storage/_base.py b/synapse/storage/_base.py index 066f3d08ae..e12ab94576 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -127,6 +127,8 @@ class SQLBaseStore(metaclass=ABCMeta): # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) + self._attempt_to_invalidate_cache("get_room_type", (room_id,)) + self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) def _invalidate_state_caches_all(self, room_id: str) -> None: """Invalidates caches that are based on the current state, but does @@ -153,6 +155,8 @@ class SQLBaseStore(metaclass=ABCMeta): "_get_rooms_for_local_user_where_membership_is_inner", None ) self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) + self._attempt_to_invalidate_cache("get_room_type", (room_id,)) + self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 26b8e1a172..63624f3e8f 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -268,13 +268,23 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined] if data.type == EventTypes.Member: - self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined] + self._attempt_to_invalidate_cache( + "get_rooms_for_user", (data.state_key,) + ) + elif data.type == EventTypes.RoomEncryption: + self._attempt_to_invalidate_cache( + "get_room_encryption", (data.room_id,) + ) + elif data.type == EventTypes.Create: + self._attempt_to_invalidate_cache("get_room_type", (data.room_id,)) elif row.type == EventsStreamAllStateRow.TypeId: assert isinstance(data, 
EventsStreamAllStateRow) # Similar to the above, but the entire caches are invalidated. This is # unfortunate for the membership caches, but should recover quickly. self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined] - self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined] + self._attempt_to_invalidate_cache("get_rooms_for_user", None) + self._attempt_to_invalidate_cache("get_room_type", (data.room_id,)) + self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,)) else: raise Exception("Unknown events stream row type %s" % (row.type,)) @@ -345,6 +355,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_forgotten_rooms_for_user", (state_key,) ) + elif etype == EventTypes.Create: + self._attempt_to_invalidate_cache("get_room_type", (room_id,)) + elif etype == EventTypes.RoomEncryption: + self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) if relates_to: self._attempt_to_invalidate_cache( @@ -405,6 +419,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_thread_summary", None) self._attempt_to_invalidate_cache("get_thread_participated", None) self._attempt_to_invalidate_cache("get_threads", (room_id,)) + self._attempt_to_invalidate_cache("get_room_type", (room_id,)) + self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) self._attempt_to_invalidate_cache("_get_state_group_for_event", None) @@ -457,6 +473,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) self._attempt_to_invalidate_cache("get_room_version_id", (room_id,)) + self._attempt_to_invalidate_cache("get_room_type", (room_id,)) + self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) # And delete state caches. 
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 5188b2f7a4..62bc4600fb 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -30,6 +30,7 @@ from typing import ( Iterable, List, Mapping, + MutableMapping, Optional, Set, Tuple, @@ -72,10 +73,18 @@ logger = logging.getLogger(__name__) _T = TypeVar("_T") - MAX_STATE_DELTA_HOPS = 100 +# Freeze so it's immutable and we can use it as a cache value +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Sentinel: + pass + + +ROOM_UNKNOWN_SENTINEL = Sentinel() + + @attr.s(slots=True, frozen=True, auto_attribs=True) class EventMetadata: """Returned by `get_metadata_for_events`""" @@ -300,51 +309,189 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @cached(max_entries=10000) async def get_room_type(self, room_id: str) -> Optional[str]: - """Get the room type for a given room. The server must be joined to the - given room. - """ - - row = await self.db_pool.simple_select_one( - table="room_stats_state", - keyvalues={"room_id": room_id}, - retcols=("room_type",), - allow_none=True, - desc="get_room_type", - ) - - if row is not None: - return row[0] - - # If we haven't updated `room_stats_state` with the room yet, query the - # create event directly. - create_event = await self.get_create_event_for_room(room_id) - room_type = create_event.content.get(EventContentFields.ROOM_TYPE) - return room_type + raise NotImplementedError() @cachedList(cached_method_name="get_room_type", list_name="room_ids") async def bulk_get_room_type( self, room_ids: Set[str] - ) -> Mapping[str, Optional[str]]: - """Bulk fetch room types for the given rooms, the server must be in all - the rooms given. + ) -> Mapping[str, Union[Optional[str], Sentinel]]: """ + Bulk fetch room types for the given rooms (via current state). 
- rows = await self.db_pool.simple_select_many_batch( - table="room_stats_state", - column="room_id", - iterable=room_ids, - retcols=("room_id", "room_type"), - desc="bulk_get_room_type", + Since this function is cached, any missing values would be cached as `None`. In + order to distinguish between an unencrypted room that has `None` encryption and + a room that is unknown to the server where we might want to omit the value + (which would make it cached as `None`), instead we use the sentinel value + `ROOM_UNKNOWN_SENTINEL`. + + Returns: + A mapping from room ID to the room's type (`None` is a valid room type). + Rooms unknown to this server will return `ROOM_UNKNOWN_SENTINEL`. + """ + + def txn( + txn: LoggingTransaction, + ) -> MutableMapping[str, Union[Optional[str], Sentinel]]: + clause, args = make_in_list_sql_clause( + txn.database_engine, "room_id", room_ids + ) + + # We can't rely on `room_stats_state.room_type` if the server has left the + # room because the `room_id` will still be in the table but everything will + # be set to `None` but `None` is a valid room type value. We join against + # the `room_stats_current` table which keeps track of the + # `current_state_events` count (and a proxy value `local_users_in_room` + # which can used to assume the server is participating in the room and has + # current state) to ensure that the data in `room_stats_state` is up-to-date + # with the current state. + # + # FIXME: Use `room_stats_current.current_state_events` instead of + # `room_stats_current.local_users_in_room` once + # https://github.com/element-hq/synapse/issues/17457 is fixed. 
+ sql = f""" + SELECT room_id, room_type + FROM room_stats_state + INNER JOIN room_stats_current USING (room_id) + WHERE + {clause} + AND local_users_in_room > 0 + """ + + txn.execute(sql, args) + + room_id_to_type_map = {} + for row in txn: + room_id_to_type_map[row[0]] = row[1] + + return room_id_to_type_map + + results = await self.db_pool.runInteraction( + "bulk_get_room_type", + txn, ) # If we haven't updated `room_stats_state` with the room yet, query the # create events directly. This should happen only rarely so we don't # mind if we do this in a loop. - results = dict(rows) for room_id in room_ids - results.keys(): - create_event = await self.get_create_event_for_room(room_id) - room_type = create_event.content.get(EventContentFields.ROOM_TYPE) - results[room_id] = room_type + try: + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + results[room_id] = room_type + except NotFoundError: + # We use the sentinel value to distinguish between `None` which is a + # valid room type and a room that is unknown to the server so the value + # is just unset. + results[room_id] = ROOM_UNKNOWN_SENTINEL + + return results + + @cached(max_entries=10000) + async def get_room_encryption(self, room_id: str) -> Optional[str]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_room_encryption", list_name="room_ids") + async def bulk_get_room_encryption( + self, room_ids: Set[str] + ) -> Mapping[str, Union[Optional[str], Sentinel]]: + """ + Bulk fetch room encryption for the given rooms (via current state). + + Since this function is cached, any missing values would be cached as `None`. In + order to distinguish between an unencrypted room that has `None` encryption and + a room that is unknown to the server where we might want to omit the value + (which would make it cached as `None`), instead we use the sentinel value + `ROOM_UNKNOWN_SENTINEL`. 
+ + Returns: + A mapping from room ID to the room's encryption algorithm if the room is + encrypted, otherwise `None`. Rooms unknown to this server will return + `ROOM_UNKNOWN_SENTINEL`. + """ + + def txn( + txn: LoggingTransaction, + ) -> MutableMapping[str, Union[Optional[str], Sentinel]]: + clause, args = make_in_list_sql_clause( + txn.database_engine, "room_id", room_ids + ) + + # We can't rely on `room_stats_state.encryption` if the server has left the + # room because the `room_id` will still be in the table but everything will + # be set to `None` but `None` is a valid encryption value. We join against + # the `room_stats_current` table which keeps track of the + # `current_state_events` count (and a proxy value `local_users_in_room` + # which can used to assume the server is participating in the room and has + # current state) to ensure that the data in `room_stats_state` is up-to-date + # with the current state. + # + # FIXME: Use `room_stats_current.current_state_events` instead of + # `room_stats_current.local_users_in_room` once + # https://github.com/element-hq/synapse/issues/17457 is fixed. + sql = f""" + SELECT room_id, encryption + FROM room_stats_state + INNER JOIN room_stats_current USING (room_id) + WHERE + {clause} + AND local_users_in_room > 0 + """ + + txn.execute(sql, args) + + room_id_to_encryption_map = {} + for row in txn: + room_id_to_encryption_map[row[0]] = row[1] + + return room_id_to_encryption_map + + results = await self.db_pool.runInteraction( + "bulk_get_room_encryption", + txn, + ) + + # If we haven't updated `room_stats_state` with the room yet, query the state + # directly. This should happen only rarely so we don't mind if we do this in a + # loop. 
+ encryption_event_ids: List[str] = [] + for room_id in room_ids - results.keys(): + state_map = await self.get_partial_filtered_current_state_ids( + room_id, + state_filter=StateFilter.from_types( + [ + (EventTypes.Create, ""), + (EventTypes.RoomEncryption, ""), + ] + ), + ) + # We can use the create event as a canary to tell whether the server has + # seen the room before + create_event_id = state_map.get((EventTypes.Create, "")) + encryption_event_id = state_map.get((EventTypes.RoomEncryption, "")) + + if create_event_id is None: + # We use the sentinel value to distinguish between `None` which is a + # valid room type and a room that is unknown to the server so the value + # is just unset. + results[room_id] = ROOM_UNKNOWN_SENTINEL + continue + + if encryption_event_id is None: + results[room_id] = None + else: + encryption_event_ids.append(encryption_event_id) + + encryption_event_map = await self.get_events(encryption_event_ids) + + for encryption_event_id in encryption_event_ids: + encryption_event = encryption_event_map.get(encryption_event_id) + # If the curent state says there is an encryption event, we should have it + # in the database. + assert encryption_event is not None + + results[encryption_event.room_id] = encryption_event.content.get( + EventContentFields.ENCRYPTION_ALGORITHM + ) return results |