summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/databases/main/cache.py22
-rw-r--r--synapse/storage/databases/main/state.py215
3 files changed, 205 insertions, 36 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 066f3d08ae..e12ab94576 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -127,6 +127,8 @@ class SQLBaseStore(metaclass=ABCMeta):
         # Purge other caches based on room state.
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
         self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
     def _invalidate_state_caches_all(self, room_id: str) -> None:
         """Invalidates caches that are based on the current state, but does
@@ -153,6 +155,8 @@ class SQLBaseStore(metaclass=ABCMeta):
             "_get_rooms_for_local_user_where_membership_is_inner", None
         )
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 26b8e1a172..63624f3e8f 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -268,13 +268,23 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)  # type: ignore[attr-defined]
 
             if data.type == EventTypes.Member:
-                self.get_rooms_for_user.invalidate((data.state_key,))  # type: ignore[attr-defined]
+                self._attempt_to_invalidate_cache(
+                    "get_rooms_for_user", (data.state_key,)
+                )
+            elif data.type == EventTypes.RoomEncryption:
+                self._attempt_to_invalidate_cache(
+                    "get_room_encryption", (data.room_id,)
+                )
+            elif data.type == EventTypes.Create:
+                self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
         elif row.type == EventsStreamAllStateRow.TypeId:
             assert isinstance(data, EventsStreamAllStateRow)
             # Similar to the above, but the entire caches are invalidated. This is
             # unfortunate for the membership caches, but should recover quickly.
             self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)  # type: ignore[attr-defined]
-            self.get_rooms_for_user.invalidate_all()  # type: ignore[attr-defined]
+            self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+            self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
+            self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,))
         else:
             raise Exception("Unknown events stream row type %s" % (row.type,))
 
@@ -345,6 +355,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._attempt_to_invalidate_cache(
                 "get_forgotten_rooms_for_user", (state_key,)
             )
+        elif etype == EventTypes.Create:
+            self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        elif etype == EventTypes.RoomEncryption:
+            self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
         if relates_to:
             self._attempt_to_invalidate_cache(
@@ -405,6 +419,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache("get_thread_summary", None)
         self._attempt_to_invalidate_cache("get_thread_participated", None)
         self._attempt_to_invalidate_cache("get_threads", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
         self._attempt_to_invalidate_cache("_get_state_group_for_event", None)
 
@@ -457,6 +473,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
         self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
         self._attempt_to_invalidate_cache("get_room_version_id", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
         # And delete state caches.
 
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5188b2f7a4..62bc4600fb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -30,6 +30,7 @@ from typing import (
     Iterable,
     List,
     Mapping,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -72,10 +73,18 @@ logger = logging.getLogger(__name__)
 
 _T = TypeVar("_T")
 
-
 MAX_STATE_DELTA_HOPS = 100
 
 
+# Freeze so it's immutable and we can use it as a cache value
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Sentinel:
+    pass
+
+
+ROOM_UNKNOWN_SENTINEL = Sentinel()
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class EventMetadata:
     """Returned by `get_metadata_for_events`"""
@@ -300,51 +309,189 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @cached(max_entries=10000)
     async def get_room_type(self, room_id: str) -> Optional[str]:
-        """Get the room type for a given room. The server must be joined to the
-        given room.
-        """
-
-        row = await self.db_pool.simple_select_one(
-            table="room_stats_state",
-            keyvalues={"room_id": room_id},
-            retcols=("room_type",),
-            allow_none=True,
-            desc="get_room_type",
-        )
-
-        if row is not None:
-            return row[0]
-
-        # If we haven't updated `room_stats_state` with the room yet, query the
-        # create event directly.
-        create_event = await self.get_create_event_for_room(room_id)
-        room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
-        return room_type
+        raise NotImplementedError()
 
     @cachedList(cached_method_name="get_room_type", list_name="room_ids")
     async def bulk_get_room_type(
         self, room_ids: Set[str]
-    ) -> Mapping[str, Optional[str]]:
-        """Bulk fetch room types for the given rooms, the server must be in all
-        the rooms given.
+    ) -> Mapping[str, Union[Optional[str], Sentinel]]:
         """
+        Bulk fetch room types for the given rooms (via current state).
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="room_stats_state",
-            column="room_id",
-            iterable=room_ids,
-            retcols=("room_id", "room_type"),
-            desc="bulk_get_room_type",
+        Since this function is cached, any missing values would be cached as `None`. In
+        order to distinguish between an unencrypted room that has `None` encryption and
+        a room that is unknown to the server where we might want to omit the value
+        (which would make it cached as `None`), instead we use the sentinel value
+        `ROOM_UNKNOWN_SENTINEL`.
+
+        Returns:
+            A mapping from room ID to the room's type (`None` is a valid room type).
+            Rooms unknown to this server will return `ROOM_UNKNOWN_SENTINEL`.
+        """
+
+        def txn(
+            txn: LoggingTransaction,
+        ) -> MutableMapping[str, Union[Optional[str], Sentinel]]:
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "room_id", room_ids
+            )
+
+            # We can't rely on `room_stats_state.room_type` if the server has left the
+            # room because the `room_id` will still be in the table but everything will
+            # be set to `None` but `None` is a valid room type value. We join against
+            # the `room_stats_current` table which keeps track of the
+            # `current_state_events` count (and a proxy value `local_users_in_room`
+            # which can used to assume the server is participating in the room and has
+            # current state) to ensure that the data in `room_stats_state` is up-to-date
+            # with the current state.
+            #
+            # FIXME: Use `room_stats_current.current_state_events` instead of
+            # `room_stats_current.local_users_in_room` once
+            # https://github.com/element-hq/synapse/issues/17457 is fixed.
+            sql = f"""
+                SELECT room_id, room_type
+                FROM room_stats_state
+                INNER JOIN room_stats_current USING (room_id)
+                WHERE
+                    {clause}
+                    AND local_users_in_room > 0
+            """
+
+            txn.execute(sql, args)
+
+            room_id_to_type_map = {}
+            for row in txn:
+                room_id_to_type_map[row[0]] = row[1]
+
+            return room_id_to_type_map
+
+        results = await self.db_pool.runInteraction(
+            "bulk_get_room_type",
+            txn,
         )
 
         # If we haven't updated `room_stats_state` with the room yet, query the
         # create events directly. This should happen only rarely so we don't
         # mind if we do this in a loop.
-        results = dict(rows)
         for room_id in room_ids - results.keys():
-            create_event = await self.get_create_event_for_room(room_id)
-            room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
-            results[room_id] = room_type
+            try:
+                create_event = await self.get_create_event_for_room(room_id)
+                room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+                results[room_id] = room_type
+            except NotFoundError:
+                # We use the sentinel value to distinguish between `None` which is a
+                # valid room type and a room that is unknown to the server so the value
+                # is just unset.
+                results[room_id] = ROOM_UNKNOWN_SENTINEL
+
+        return results
+
+    @cached(max_entries=10000)
+    async def get_room_encryption(self, room_id: str) -> Optional[str]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_room_encryption", list_name="room_ids")
+    async def bulk_get_room_encryption(
+        self, room_ids: Set[str]
+    ) -> Mapping[str, Union[Optional[str], Sentinel]]:
+        """
+        Bulk fetch room encryption for the given rooms (via current state).
+
+        Since this function is cached, any missing values would be cached as `None`. In
+        order to distinguish between an unencrypted room that has `None` encryption and
+        a room that is unknown to the server where we might want to omit the value
+        (which would make it cached as `None`), instead we use the sentinel value
+        `ROOM_UNKNOWN_SENTINEL`.
+
+        Returns:
+            A mapping from room ID to the room's encryption algorithm if the room is
+            encrypted, otherwise `None`. Rooms unknown to this server will return
+            `ROOM_UNKNOWN_SENTINEL`.
+        """
+
+        def txn(
+            txn: LoggingTransaction,
+        ) -> MutableMapping[str, Union[Optional[str], Sentinel]]:
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "room_id", room_ids
+            )
+
+            # We can't rely on `room_stats_state.encryption` if the server has left the
+            # room because the `room_id` will still be in the table but everything will
+            # be set to `None` but `None` is a valid encryption value. We join against
+            # the `room_stats_current` table which keeps track of the
+            # `current_state_events` count (and a proxy value `local_users_in_room`
+            # which can used to assume the server is participating in the room and has
+            # current state) to ensure that the data in `room_stats_state` is up-to-date
+            # with the current state.
+            #
+            # FIXME: Use `room_stats_current.current_state_events` instead of
+            # `room_stats_current.local_users_in_room` once
+            # https://github.com/element-hq/synapse/issues/17457 is fixed.
+            sql = f"""
+                SELECT room_id, encryption
+                FROM room_stats_state
+                INNER JOIN room_stats_current USING (room_id)
+                WHERE
+                    {clause}
+                    AND local_users_in_room > 0
+            """
+
+            txn.execute(sql, args)
+
+            room_id_to_encryption_map = {}
+            for row in txn:
+                room_id_to_encryption_map[row[0]] = row[1]
+
+            return room_id_to_encryption_map
+
+        results = await self.db_pool.runInteraction(
+            "bulk_get_room_encryption",
+            txn,
+        )
+
+        # If we haven't updated `room_stats_state` with the room yet, query the state
+        # directly. This should happen only rarely so we don't mind if we do this in a
+        # loop.
+        encryption_event_ids: List[str] = []
+        for room_id in room_ids - results.keys():
+            state_map = await self.get_partial_filtered_current_state_ids(
+                room_id,
+                state_filter=StateFilter.from_types(
+                    [
+                        (EventTypes.Create, ""),
+                        (EventTypes.RoomEncryption, ""),
+                    ]
+                ),
+            )
+            # We can use the create event as a canary to tell whether the server has
+            # seen the room before
+            create_event_id = state_map.get((EventTypes.Create, ""))
+            encryption_event_id = state_map.get((EventTypes.RoomEncryption, ""))
+
+            if create_event_id is None:
+                # We use the sentinel value to distinguish between `None` which is a
+                # valid room type and a room that is unknown to the server so the value
+                # is just unset.
+                results[room_id] = ROOM_UNKNOWN_SENTINEL
+                continue
+
+            if encryption_event_id is None:
+                results[room_id] = None
+            else:
+                encryption_event_ids.append(encryption_event_id)
+
+        encryption_event_map = await self.get_events(encryption_event_ids)
+
+        for encryption_event_id in encryption_event_ids:
+            encryption_event = encryption_event_map.get(encryption_event_id)
+            # If the curent state says there is an encryption event, we should have it
+            # in the database.
+            assert encryption_event is not None
+
+            results[encryption_event.room_id] = encryption_event.content.get(
+                EventContentFields.ENCRYPTION_ALGORITHM
+            )
 
         return results