summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorEric Eastwood <eric.eastwood@beta.gouv.fr>2024-07-30 13:20:29 -0500
committerGitHub <noreply@github.com>2024-07-30 13:20:29 -0500
commit46de0ee16be8731f0ed68654edc75aced1510b19 (patch)
treeb1538d66ed8ed2651c7db671fe76231542aa33d9 /synapse
parentSliding Sync: Add receipts extension (MSC3960) (#17489) (diff)
downloadsynapse-46de0ee16be8731f0ed68654edc75aced1510b19.tar.xz
Sliding Sync: Update filters to be robust against remote invite rooms (#17450)
Update `filters.is_encrypted` and `filters.types`/`filters.not_types` to
be robust when dealing with remote invite rooms in Sliding Sync.

Part of
[MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575):
Sliding Sync

Follow-up to https://github.com/element-hq/synapse/pull/17434

We now take into account current state, fallback to stripped state
for invite/knock rooms, then historical state. If we can't determine
the info needed to filter a room (either from state or stripped state),
it is filtered out.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/constants.py8
-rw-r--r--synapse/events/__init__.py19
-rw-r--r--synapse/events/utils.py29
-rw-r--r--synapse/handlers/sliding_sync.py385
-rw-r--r--synapse/handlers/stats.py4
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/databases/main/cache.py22
-rw-r--r--synapse/storage/databases/main/state.py215
8 files changed, 611 insertions, 75 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 85001d9676..7dcb1e01fd 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -225,6 +225,11 @@ class EventContentFields:
     # This is deprecated in MSC2175.
     ROOM_CREATOR: Final = "creator"
 
+    # The version of the room for `m.room.create` events.
+    ROOM_VERSION: Final = "room_version"
+
+    ROOM_NAME: Final = "name"
+
     # Used in m.room.guest_access events.
     GUEST_ACCESS: Final = "guest_access"
 
@@ -237,6 +242,9 @@ class EventContentFields:
     # an unspecced field added to to-device messages to identify them uniquely-ish
     TO_DEVICE_MSGID: Final = "org.matrix.msgid"
 
+    # `m.room.encryption`` algorithm field
+    ENCRYPTION_ALGORITHM: Final = "algorithm"
+
 
 class EventUnsignedContentFields:
     """Fields found inside the 'unsigned' data on events"""
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 36e0f47e51..2e56b671f0 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -554,3 +554,22 @@ def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
             aggregation_key = None
 
     return _EventRelation(parent_id, rel_type, aggregation_key)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StrippedStateEvent:
+    """
+    A stripped down state event. Usually used for remote invite/knocks so the user can
+    make an informed decision on whether they want to join.
+
+    Attributes:
+        type: Event `type`
+        state_key: Event `state_key`
+        sender: Event `sender`
+        content: Event `content`
+    """
+
+    type: str
+    state_key: str
+    sender: str
+    content: Dict[str, Any]
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f937fd4698..54f94add4d 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -49,7 +49,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.types import JsonDict, Requester
 
-from . import EventBase, make_event_from_dict
+from . import EventBase, StrippedStateEvent, make_event_from_dict
 
 if TYPE_CHECKING:
     from synapse.handlers.relations import BundledAggregations
@@ -854,3 +854,30 @@ def strip_event(event: EventBase) -> JsonDict:
         "content": event.content,
         "sender": event.sender,
     }
+
+
+def parse_stripped_state_event(raw_stripped_event: Any) -> Optional[StrippedStateEvent]:
+    """
+    Given a raw value from an event's `unsigned` field, attempt to parse it into a
+    `StrippedStateEvent`.
+    """
+    if isinstance(raw_stripped_event, dict):
+        # All of these fields are required
+        type = raw_stripped_event.get("type")
+        state_key = raw_stripped_event.get("state_key")
+        sender = raw_stripped_event.get("sender")
+        content = raw_stripped_event.get("content")
+        if (
+            isinstance(type, str)
+            and isinstance(state_key, str)
+            and isinstance(sender, str)
+            and isinstance(content, dict)
+        ):
+            return StrippedStateEvent(
+                type=type,
+                state_key=state_key,
+                sender=sender,
+                content=content,
+            )
+
+    return None
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 7a734f6712..530e7b7b4e 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -17,6 +17,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
+import enum
 import logging
 from enum import Enum
 from itertools import chain
@@ -26,23 +27,35 @@ from typing import (
     Dict,
     Final,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
     Set,
     Tuple,
+    Union,
 )
 
 import attr
 from immutabledict import immutabledict
 from typing_extensions import assert_never
 
-from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership
-from synapse.events import EventBase
-from synapse.events.utils import strip_event
+from synapse.api.constants import (
+    AccountDataTypes,
+    Direction,
+    EventContentFields,
+    EventTypes,
+    Membership,
+)
+from synapse.events import EventBase, StrippedStateEvent
+from synapse.events.utils import parse_stripped_state_event, strip_event
 from synapse.handlers.relations import BundledAggregations
 from synapse.logging.opentracing import log_kv, start_active_span, tag_args, trace
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
+from synapse.storage.databases.main.state import (
+    ROOM_UNKNOWN_SENTINEL,
+    Sentinel as StateSentinel,
+)
 from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
 from synapse.storage.roommember import MemberSummary
 from synapse.types import (
@@ -50,6 +63,7 @@ from synapse.types import (
     JsonDict,
     JsonMapping,
     MultiWriterStreamToken,
+    MutableStateMap,
     PersistedEventPosition,
     Requester,
     RoomStreamToken,
@@ -71,6 +85,12 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup and subsequent type narrowing.
+    UNSET_SENTINEL = object()
+
+
 # The event types that clients should consider as new activity.
 DEFAULT_BUMP_EVENT_TYPES = {
     EventTypes.Create,
@@ -1172,6 +1192,265 @@ class SlidingSyncHandler:
 
         # return None
 
+    async def _bulk_get_stripped_state_for_rooms_from_sync_room_map(
+        self,
+        room_ids: StrCollection,
+        sync_room_map: Dict[str, _RoomMembershipForUser],
+    ) -> Dict[str, Optional[StateMap[StrippedStateEvent]]]:
+        """
+        Fetch stripped state for a list of room IDs. Stripped state is only
+        applicable to invite/knock rooms. Other rooms will have `None` as their
+        stripped state.
+
+        For invite rooms, we pull from `unsigned.invite_room_state`.
+        For knock rooms, we pull from `unsigned.knock_room_state`.
+
+        Args:
+            room_ids: Room IDs to fetch stripped state for
+            sync_room_map: Dictionary of room IDs to sort along with membership
+                information in the room at the time of `to_token`.
+
+        Returns:
+            Mapping from room_id to mapping of (type, state_key) to stripped state
+            event.
+        """
+        room_id_to_stripped_state_map: Dict[
+            str, Optional[StateMap[StrippedStateEvent]]
+        ] = {}
+
+        # Fetch what we haven't before
+        room_ids_to_fetch = [
+            room_id
+            for room_id in room_ids
+            if room_id not in room_id_to_stripped_state_map
+        ]
+
+        # Gather a list of event IDs we can grab stripped state from
+        invite_or_knock_event_ids: List[str] = []
+        for room_id in room_ids_to_fetch:
+            if sync_room_map[room_id].membership in (
+                Membership.INVITE,
+                Membership.KNOCK,
+            ):
+                event_id = sync_room_map[room_id].event_id
+                # If this is an invite/knock then there should be an event_id
+                assert event_id is not None
+                invite_or_knock_event_ids.append(event_id)
+            else:
+                room_id_to_stripped_state_map[room_id] = None
+
+        invite_or_knock_events = await self.store.get_events(invite_or_knock_event_ids)
+        for invite_or_knock_event in invite_or_knock_events.values():
+            room_id = invite_or_knock_event.room_id
+            membership = invite_or_knock_event.membership
+
+            raw_stripped_state_events = None
+            if membership == Membership.INVITE:
+                invite_room_state = invite_or_knock_event.unsigned.get(
+                    "invite_room_state"
+                )
+                raw_stripped_state_events = invite_room_state
+            elif membership == Membership.KNOCK:
+                knock_room_state = invite_or_knock_event.unsigned.get(
+                    "knock_room_state"
+                )
+                raw_stripped_state_events = knock_room_state
+            else:
+                raise AssertionError(
+                    f"Unexpected membership {membership} (this is a problem with Synapse itself)"
+                )
+
+            stripped_state_map: Optional[MutableStateMap[StrippedStateEvent]] = None
+            # Scrutinize unsigned things. `raw_stripped_state_events` should be a list
+            # of stripped events
+            if raw_stripped_state_events is not None:
+                stripped_state_map = {}
+                if isinstance(raw_stripped_state_events, list):
+                    for raw_stripped_event in raw_stripped_state_events:
+                        stripped_state_event = parse_stripped_state_event(
+                            raw_stripped_event
+                        )
+                        if stripped_state_event is not None:
+                            stripped_state_map[
+                                (
+                                    stripped_state_event.type,
+                                    stripped_state_event.state_key,
+                                )
+                            ] = stripped_state_event
+
+            room_id_to_stripped_state_map[room_id] = stripped_state_map
+
+        return room_id_to_stripped_state_map
+
+    async def _bulk_get_partial_current_state_content_for_rooms(
+        self,
+        content_type: Literal[
+            # `content.type` from `EventTypes.Create``
+            "room_type",
+            # `content.algorithm` from `EventTypes.RoomEncryption`
+            "room_encryption",
+        ],
+        room_ids: Set[str],
+        sync_room_map: Dict[str, _RoomMembershipForUser],
+        to_token: StreamToken,
+        room_id_to_stripped_state_map: Dict[
+            str, Optional[StateMap[StrippedStateEvent]]
+        ],
+    ) -> Mapping[str, Union[Optional[str], StateSentinel]]:
+        """
+        Get the given state event content for a list of rooms. First we check the
+        current state of the room, then fallback to stripped state if available, then
+        historical state.
+
+        Args:
+            content_type: Which content to grab
+            room_ids: Room IDs to fetch the given content field for.
+            sync_room_map: Dictionary of room IDs to sort along with membership
+                information in the room at the time of `to_token`.
+            to_token: We filter based on the state of the room at this token
+            room_id_to_stripped_state_map: This does not need to be filled in before
+                calling this function. Mapping from room_id to mapping of (type, state_key)
+                to stripped state event. Modified in place when we fetch new rooms so we can
+                save work next time this function is called.
+
+        Returns:
+            A mapping from room ID to the state event content if the room has
+            the given state event (event_type, ""), otherwise `None`. Rooms unknown to
+            this server will return `ROOM_UNKNOWN_SENTINEL`.
+        """
+        room_id_to_content: Dict[str, Union[Optional[str], StateSentinel]] = {}
+
+        # As a bulk shortcut, use the current state if the server is particpating in the
+        # room (meaning we have current state). Ideally, for leave/ban rooms, we would
+        # want the state at the time of the membership instead of current state to not
+        # leak anything but we consider the create/encryption stripped state events to
+        # not be a secret given they are often set at the start of the room and they are
+        # normally handed out on invite/knock.
+        #
+        # Be mindful to only use this for non-sensitive details. For example, even
+        # though the room name/avatar/topic are also stripped state, they seem a lot
+        # more senstive to leak the current state value of.
+        #
+        # Since this function is cached, we need to make a mutable copy via
+        # `dict(...)`.
+        event_type = ""
+        event_content_field = ""
+        if content_type == "room_type":
+            event_type = EventTypes.Create
+            event_content_field = EventContentFields.ROOM_TYPE
+            room_id_to_content = dict(await self.store.bulk_get_room_type(room_ids))
+        elif content_type == "room_encryption":
+            event_type = EventTypes.RoomEncryption
+            event_content_field = EventContentFields.ENCRYPTION_ALGORITHM
+            room_id_to_content = dict(
+                await self.store.bulk_get_room_encryption(room_ids)
+            )
+        else:
+            assert_never(content_type)
+
+        room_ids_with_results = [
+            room_id
+            for room_id, content_field in room_id_to_content.items()
+            if content_field is not ROOM_UNKNOWN_SENTINEL
+        ]
+
+        # We might not have current room state for remote invite/knocks if we are
+        # the first person on our server to see the room. The best we can do is look
+        # in the optional stripped state from the invite/knock event.
+        room_ids_without_results = room_ids.difference(
+            chain(
+                room_ids_with_results,
+                [
+                    room_id
+                    for room_id, stripped_state_map in room_id_to_stripped_state_map.items()
+                    if stripped_state_map is not None
+                ],
+            )
+        )
+        room_id_to_stripped_state_map.update(
+            await self._bulk_get_stripped_state_for_rooms_from_sync_room_map(
+                room_ids_without_results, sync_room_map
+            )
+        )
+
+        # Update our `room_id_to_content` map based on the stripped state
+        # (applies to invite/knock rooms)
+        rooms_ids_without_stripped_state: Set[str] = set()
+        for room_id in room_ids_without_results:
+            stripped_state_map = room_id_to_stripped_state_map.get(
+                room_id, Sentinel.UNSET_SENTINEL
+            )
+            assert stripped_state_map is not Sentinel.UNSET_SENTINEL, (
+                f"Stripped state left unset for room {room_id}. "
+                + "Make sure you're calling `_bulk_get_stripped_state_for_rooms_from_sync_room_map(...)` "
+                + "with that room_id. (this is a problem with Synapse itself)"
+            )
+
+            # If there is some stripped state, we assume the remote server passed *all*
+            # of the potential stripped state events for the room.
+            if stripped_state_map is not None:
+                create_stripped_event = stripped_state_map.get((EventTypes.Create, ""))
+                stripped_event = stripped_state_map.get((event_type, ""))
+                # Sanity check that we at-least have the create event
+                if create_stripped_event is not None:
+                    if stripped_event is not None:
+                        room_id_to_content[room_id] = stripped_event.content.get(
+                            event_content_field
+                        )
+                    else:
+                        # Didn't see the state event we're looking for in the stripped
+                        # state so we can assume relevant content field is `None`.
+                        room_id_to_content[room_id] = None
+            else:
+                rooms_ids_without_stripped_state.add(room_id)
+
+        # Last resort, we might not have current room state for rooms that the
+        # server has left (no one local is in the room) but we can look at the
+        # historical state.
+        #
+        # Update our `room_id_to_content` map based on the state at the time of
+        # the membership event.
+        for room_id in rooms_ids_without_stripped_state:
+            # TODO: It would be nice to look this up in a bulk way (N+1 queries)
+            #
+            # TODO: `get_state_at(...)` doesn't take into account the "current state".
+            room_state = await self.storage_controllers.state.get_state_at(
+                room_id=room_id,
+                stream_position=to_token.copy_and_replace(
+                    StreamKeyType.ROOM,
+                    sync_room_map[room_id].event_pos.to_room_stream_token(),
+                ),
+                state_filter=StateFilter.from_types(
+                    [
+                        (EventTypes.Create, ""),
+                        (event_type, ""),
+                    ]
+                ),
+                # Partially-stated rooms should have all state events except for
+                # remote membership events so we don't need to wait at all because
+                # we only want the create event and some non-member event.
+                await_full_state=False,
+            )
+            # We can use the create event as a canary to tell whether the server has
+            # seen the room before
+            create_event = room_state.get((EventTypes.Create, ""))
+            state_event = room_state.get((event_type, ""))
+
+            if create_event is None:
+                # Skip for unknown rooms
+                continue
+
+            if state_event is not None:
+                room_id_to_content[room_id] = state_event.content.get(
+                    event_content_field
+                )
+            else:
+                # Didn't see the state event we're looking for in the stripped
+                # state so we can assume relevant content field is `None`.
+                room_id_to_content[room_id] = None
+
+        return room_id_to_content
+
     @trace
     async def filter_rooms(
         self,
@@ -1194,6 +1473,10 @@ class SlidingSyncHandler:
             A filtered dictionary of room IDs along with membership information in the
             room at the time of `to_token`.
         """
+        room_id_to_stripped_state_map: Dict[
+            str, Optional[StateMap[StrippedStateEvent]]
+        ] = {}
+
         filtered_room_id_set = set(sync_room_map.keys())
 
         # Filter for Direct-Message (DM) rooms
@@ -1213,31 +1496,34 @@ class SlidingSyncHandler:
                     if not sync_room_map[room_id].is_dm
                 }
 
-        if filters.spaces:
+        if filters.spaces is not None:
             raise NotImplementedError()
 
         # Filter for encrypted rooms
         if filters.is_encrypted is not None:
+            room_id_to_encryption = (
+                await self._bulk_get_partial_current_state_content_for_rooms(
+                    content_type="room_encryption",
+                    room_ids=filtered_room_id_set,
+                    to_token=to_token,
+                    sync_room_map=sync_room_map,
+                    room_id_to_stripped_state_map=room_id_to_stripped_state_map,
+                )
+            )
+
             # Make a copy so we don't run into an error: `Set changed size during
             # iteration`, when we filter out and remove items
             for room_id in filtered_room_id_set.copy():
-                state_at_to_token = await self.storage_controllers.state.get_state_at(
-                    room_id,
-                    to_token,
-                    state_filter=StateFilter.from_types(
-                        [(EventTypes.RoomEncryption, "")]
-                    ),
-                    # Partially-stated rooms should have all state events except for the
-                    # membership events so we don't need to wait because we only care
-                    # about retrieving the `EventTypes.RoomEncryption` state event here.
-                    # Plus we don't want to block the whole sync waiting for this one
-                    # room.
-                    await_full_state=False,
-                )
-                is_encrypted = state_at_to_token.get((EventTypes.RoomEncryption, ""))
+                encryption = room_id_to_encryption.get(room_id, ROOM_UNKNOWN_SENTINEL)
+
+                # Just remove rooms if we can't determine their encryption status
+                if encryption is ROOM_UNKNOWN_SENTINEL:
+                    filtered_room_id_set.remove(room_id)
+                    continue
 
                 # If we're looking for encrypted rooms, filter out rooms that are not
                 # encrypted and vice versa
+                is_encrypted = encryption is not None
                 if (filters.is_encrypted and not is_encrypted) or (
                     not filters.is_encrypted and is_encrypted
                 ):
@@ -1263,15 +1549,26 @@ class SlidingSyncHandler:
         # provided in the list. `None` is a valid type for rooms which do not have a
         # room type.
         if filters.room_types is not None or filters.not_room_types is not None:
-            room_to_type = await self.store.bulk_get_room_type(
-                {
-                    room_id
-                    for room_id in filtered_room_id_set
-                    # We only know the room types for joined rooms
-                    if sync_room_map[room_id].membership == Membership.JOIN
-                }
+            room_id_to_type = (
+                await self._bulk_get_partial_current_state_content_for_rooms(
+                    content_type="room_type",
+                    room_ids=filtered_room_id_set,
+                    to_token=to_token,
+                    sync_room_map=sync_room_map,
+                    room_id_to_stripped_state_map=room_id_to_stripped_state_map,
+                )
             )
-            for room_id, room_type in room_to_type.items():
+
+            # Make a copy so we don't run into an error: `Set changed size during
+            # iteration`, when we filter out and remove items
+            for room_id in filtered_room_id_set.copy():
+                room_type = room_id_to_type.get(room_id, ROOM_UNKNOWN_SENTINEL)
+
+                # Just remove rooms if we can't determine their type
+                if room_type is ROOM_UNKNOWN_SENTINEL:
+                    filtered_room_id_set.remove(room_id)
+                    continue
+
                 if (
                     filters.room_types is not None
                     and room_type not in filters.room_types
@@ -1284,13 +1581,24 @@ class SlidingSyncHandler:
                 ):
                     filtered_room_id_set.remove(room_id)
 
-        if filters.room_name_like:
+        if filters.room_name_like is not None:
+            # TODO: The room name is a bit more sensitive to leak than the
+            # create/encryption event. Maybe we should consider a better way to fetch
+            # historical state before implementing this.
+            #
+            # room_id_to_create_content = await self._bulk_get_partial_current_state_content_for_rooms(
+            #     content_type="room_name",
+            #     room_ids=filtered_room_id_set,
+            #     to_token=to_token,
+            #     sync_room_map=sync_room_map,
+            #     room_id_to_stripped_state_map=room_id_to_stripped_state_map,
+            # )
             raise NotImplementedError()
 
-        if filters.tags:
+        if filters.tags is not None:
             raise NotImplementedError()
 
-        if filters.not_tags:
+        if filters.not_tags is not None:
             raise NotImplementedError()
 
         # Assemble a new sync room map but only with the `filtered_room_id_set`
@@ -1371,14 +1679,17 @@ class SlidingSyncHandler:
                 in the room at the time of `to_token`.
             to_token: The point in the stream to sync up to.
         """
-        room_state_ids: StateMap[str]
+        state_ids: StateMap[str]
         # People shouldn't see past their leave/ban event
         if room_membership_for_user_at_to_token.membership in (
             Membership.LEAVE,
             Membership.BAN,
         ):
-            # TODO: `get_state_ids_at(...)` doesn't take into account the "current state"
-            room_state_ids = await self.storage_controllers.state.get_state_ids_at(
+            # TODO: `get_state_ids_at(...)` doesn't take into account the "current
+            # state". Maybe we need to use
+            # `get_forward_extremities_for_room_at_stream_ordering(...)` to "Fetch the
+            # current state at the time."
+            state_ids = await self.storage_controllers.state.get_state_ids_at(
                 room_id,
                 stream_position=to_token.copy_and_replace(
                     StreamKeyType.ROOM,
@@ -1397,7 +1708,7 @@ class SlidingSyncHandler:
             )
         # Otherwise, we can get the latest current state in the room
         else:
-            room_state_ids = await self.storage_controllers.state.get_current_state_ids(
+            state_ids = await self.storage_controllers.state.get_current_state_ids(
                 room_id,
                 state_filter,
                 # Partially-stated rooms should have all state events except for
@@ -1412,7 +1723,7 @@ class SlidingSyncHandler:
             )
             # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token`
 
-        return room_state_ids
+        return state_ids
 
     async def get_current_state_at(
         self,
@@ -1432,17 +1743,17 @@ class SlidingSyncHandler:
                 in the room at the time of `to_token`.
             to_token: The point in the stream to sync up to.
         """
-        room_state_ids = await self.get_current_state_ids_at(
+        state_ids = await self.get_current_state_ids_at(
             room_id=room_id,
             room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
             state_filter=state_filter,
             to_token=to_token,
         )
 
-        event_map = await self.store.get_events(list(room_state_ids.values()))
+        event_map = await self.store.get_events(list(state_ids.values()))
 
         state_map = {}
-        for key, event_id in room_state_ids.items():
+        for key, event_id in state_ids.items():
             event = event_map.get(event_id)
             if event:
                 state_map[key] = event
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 1c94f3ca46..8f90c17060 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -293,7 +293,9 @@ class StatsHandler:
                     "history_visibility"
                 )
             elif delta.event_type == EventTypes.RoomEncryption:
-                room_state["encryption"] = event_content.get("algorithm")
+                room_state["encryption"] = event_content.get(
+                    EventContentFields.ENCRYPTION_ALGORITHM
+                )
             elif delta.event_type == EventTypes.Name:
                 room_state["name"] = event_content.get("name")
             elif delta.event_type == EventTypes.Topic:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 066f3d08ae..e12ab94576 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -127,6 +127,8 @@ class SQLBaseStore(metaclass=ABCMeta):
         # Purge other caches based on room state.
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
         self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
     def _invalidate_state_caches_all(self, room_id: str) -> None:
         """Invalidates caches that are based on the current state, but does
@@ -153,6 +155,8 @@ class SQLBaseStore(metaclass=ABCMeta):
             "_get_rooms_for_local_user_where_membership_is_inner", None
         )
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 26b8e1a172..63624f3e8f 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -268,13 +268,23 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)  # type: ignore[attr-defined]
 
             if data.type == EventTypes.Member:
-                self.get_rooms_for_user.invalidate((data.state_key,))  # type: ignore[attr-defined]
+                self._attempt_to_invalidate_cache(
+                    "get_rooms_for_user", (data.state_key,)
+                )
+            elif data.type == EventTypes.RoomEncryption:
+                self._attempt_to_invalidate_cache(
+                    "get_room_encryption", (data.room_id,)
+                )
+            elif data.type == EventTypes.Create:
+                self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
         elif row.type == EventsStreamAllStateRow.TypeId:
             assert isinstance(data, EventsStreamAllStateRow)
             # Similar to the above, but the entire caches are invalidated. This is
             # unfortunate for the membership caches, but should recover quickly.
             self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)  # type: ignore[attr-defined]
-            self.get_rooms_for_user.invalidate_all()  # type: ignore[attr-defined]
+            self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+            self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
+            self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,))
         else:
             raise Exception("Unknown events stream row type %s" % (row.type,))
 
@@ -345,6 +355,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._attempt_to_invalidate_cache(
                 "get_forgotten_rooms_for_user", (state_key,)
             )
+        elif etype == EventTypes.Create:
+            self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        elif etype == EventTypes.RoomEncryption:
+            self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
         if relates_to:
             self._attempt_to_invalidate_cache(
@@ -405,6 +419,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache("get_thread_summary", None)
         self._attempt_to_invalidate_cache("get_thread_participated", None)
         self._attempt_to_invalidate_cache("get_threads", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
         self._attempt_to_invalidate_cache("_get_state_group_for_event", None)
 
@@ -457,6 +473,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
         self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
         self._attempt_to_invalidate_cache("get_room_version_id", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
         # And delete state caches.
 
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5188b2f7a4..62bc4600fb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -30,6 +30,7 @@ from typing import (
     Iterable,
     List,
     Mapping,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -72,10 +73,18 @@ logger = logging.getLogger(__name__)
 
 _T = TypeVar("_T")
 
-
 MAX_STATE_DELTA_HOPS = 100
 
 
+# Freeze so it's immutable and we can use it as a cache value
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Sentinel:
+    pass
+
+
+ROOM_UNKNOWN_SENTINEL = Sentinel()
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class EventMetadata:
     """Returned by `get_metadata_for_events`"""
@@ -300,51 +309,189 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @cached(max_entries=10000)
     async def get_room_type(self, room_id: str) -> Optional[str]:
-        """Get the room type for a given room. The server must be joined to the
-        given room.
-        """
-
-        row = await self.db_pool.simple_select_one(
-            table="room_stats_state",
-            keyvalues={"room_id": room_id},
-            retcols=("room_type",),
-            allow_none=True,
-            desc="get_room_type",
-        )
-
-        if row is not None:
-            return row[0]
-
-        # If we haven't updated `room_stats_state` with the room yet, query the
-        # create event directly.
-        create_event = await self.get_create_event_for_room(room_id)
-        room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
-        return room_type
+        raise NotImplementedError()
 
     @cachedList(cached_method_name="get_room_type", list_name="room_ids")
     async def bulk_get_room_type(
         self, room_ids: Set[str]
-    ) -> Mapping[str, Optional[str]]:
-        """Bulk fetch room types for the given rooms, the server must be in all
-        the rooms given.
+    ) -> Mapping[str, Union[Optional[str], Sentinel]]:
         """
+        Bulk fetch room types for the given rooms (via current state).
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="room_stats_state",
-            column="room_id",
-            iterable=room_ids,
-            retcols=("room_id", "room_type"),
-            desc="bulk_get_room_type",
+        Since this function is cached, any missing values would be cached as `None`. In
+        order to distinguish between an unencrypted room that has `None` encryption and
+        a room that is unknown to the server where we might want to omit the value
+        (which would make it cached as `None`), instead we use the sentinel value
+        `ROOM_UNKNOWN_SENTINEL`.
+
+        Returns:
+            A mapping from room ID to the room's type (`None` is a valid room type).
+            Rooms unknown to this server will return `ROOM_UNKNOWN_SENTINEL`.
+        """
+
+        def txn(
+            txn: LoggingTransaction,
+        ) -> MutableMapping[str, Union[Optional[str], Sentinel]]:
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "room_id", room_ids
+            )
+
+            # We can't rely on `room_stats_state.room_type` if the server has left the
+            # room because the `room_id` will still be in the table but everything will
+            # be set to `None` but `None` is a valid room type value. We join against
+            # the `room_stats_current` table which keeps track of the
+            # `current_state_events` count (and a proxy value `local_users_in_room`
+            # which can used to assume the server is participating in the room and has
+            # current state) to ensure that the data in `room_stats_state` is up-to-date
+            # with the current state.
+            #
+            # FIXME: Use `room_stats_current.current_state_events` instead of
+            # `room_stats_current.local_users_in_room` once
+            # https://github.com/element-hq/synapse/issues/17457 is fixed.
+            sql = f"""
+                SELECT room_id, room_type
+                FROM room_stats_state
+                INNER JOIN room_stats_current USING (room_id)
+                WHERE
+                    {clause}
+                    AND local_users_in_room > 0
+            """
+
+            txn.execute(sql, args)
+
+            room_id_to_type_map = {}
+            for row in txn:
+                room_id_to_type_map[row[0]] = row[1]
+
+            return room_id_to_type_map
+
+        results = await self.db_pool.runInteraction(
+            "bulk_get_room_type",
+            txn,
         )
 
         # If we haven't updated `room_stats_state` with the room yet, query the
         # create events directly. This should happen only rarely so we don't
         # mind if we do this in a loop.
-        results = dict(rows)
         for room_id in room_ids - results.keys():
-            create_event = await self.get_create_event_for_room(room_id)
-            room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
-            results[room_id] = room_type
+            try:
+                create_event = await self.get_create_event_for_room(room_id)
+                room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+                results[room_id] = room_type
+            except NotFoundError:
+                # We use the sentinel value to distinguish between `None` which is a
+                # valid room type and a room that is unknown to the server so the value
+                # is just unset.
+                results[room_id] = ROOM_UNKNOWN_SENTINEL
+
+        return results
+
+    @cached(max_entries=10000)
+    async def get_room_encryption(self, room_id: str) -> Optional[str]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_room_encryption", list_name="room_ids")
+    async def bulk_get_room_encryption(
+        self, room_ids: Set[str]
+    ) -> Mapping[str, Union[Optional[str], Sentinel]]:
+        """
+        Bulk fetch room encryption for the given rooms (via current state).
+
+        Since this function is cached, any missing values would be cached as `None`. In
+        order to distinguish between an unencrypted room that has `None` encryption and
+        a room that is unknown to the server where we might want to omit the value
+        (which would make it cached as `None`), instead we use the sentinel value
+        `ROOM_UNKNOWN_SENTINEL`.
+
+        Returns:
+            A mapping from room ID to the room's encryption algorithm if the room is
+            encrypted, otherwise `None`. Rooms unknown to this server will return
+            `ROOM_UNKNOWN_SENTINEL`.
+        """
+
+        def txn(
+            txn: LoggingTransaction,
+        ) -> MutableMapping[str, Union[Optional[str], Sentinel]]:
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "room_id", room_ids
+            )
+
+            # We can't rely on `room_stats_state.encryption` if the server has left the
+            # room because the `room_id` will still be in the table but everything will
+            # be set to `None` but `None` is a valid encryption value. We join against
+            # the `room_stats_current` table which keeps track of the
+            # `current_state_events` count (and a proxy value `local_users_in_room`
+            # which can used to assume the server is participating in the room and has
+            # current state) to ensure that the data in `room_stats_state` is up-to-date
+            # with the current state.
+            #
+            # FIXME: Use `room_stats_current.current_state_events` instead of
+            # `room_stats_current.local_users_in_room` once
+            # https://github.com/element-hq/synapse/issues/17457 is fixed.
+            sql = f"""
+                SELECT room_id, encryption
+                FROM room_stats_state
+                INNER JOIN room_stats_current USING (room_id)
+                WHERE
+                    {clause}
+                    AND local_users_in_room > 0
+            """
+
+            txn.execute(sql, args)
+
+            room_id_to_encryption_map = {}
+            for row in txn:
+                room_id_to_encryption_map[row[0]] = row[1]
+
+            return room_id_to_encryption_map
+
+        results = await self.db_pool.runInteraction(
+            "bulk_get_room_encryption",
+            txn,
+        )
+
+        # If we haven't updated `room_stats_state` with the room yet, query the state
+        # directly. This should happen only rarely so we don't mind if we do this in a
+        # loop.
+        encryption_event_ids: List[str] = []
+        for room_id in room_ids - results.keys():
+            state_map = await self.get_partial_filtered_current_state_ids(
+                room_id,
+                state_filter=StateFilter.from_types(
+                    [
+                        (EventTypes.Create, ""),
+                        (EventTypes.RoomEncryption, ""),
+                    ]
+                ),
+            )
+            # We can use the create event as a canary to tell whether the server has
+            # seen the room before
+            create_event_id = state_map.get((EventTypes.Create, ""))
+            encryption_event_id = state_map.get((EventTypes.RoomEncryption, ""))
+
+            if create_event_id is None:
+                # We use the sentinel value to distinguish between `None` which is a
+                # valid room type and a room that is unknown to the server so the value
+                # is just unset.
+                results[room_id] = ROOM_UNKNOWN_SENTINEL
+                continue
+
+            if encryption_event_id is None:
+                results[room_id] = None
+            else:
+                encryption_event_ids.append(encryption_event_id)
+
+        encryption_event_map = await self.get_events(encryption_event_ids)
+
+        for encryption_event_id in encryption_event_ids:
+            encryption_event = encryption_event_map.get(encryption_event_id)
+            # If the curent state says there is an encryption event, we should have it
+            # in the database.
+            assert encryption_event is not None
+
+            results[encryption_event.room_id] = encryption_event.content.get(
+                EventContentFields.ENCRYPTION_ALGORITHM
+            )
 
         return results