summary refs log tree commit diff
path: root/synapse/storage/databases/main/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/state.py')
-rw-r--r--synapse/storage/databases/main/state.py215
1 files changed, 181 insertions, 34 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5188b2f7a4..62bc4600fb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -30,6 +30,7 @@ from typing import (
     Iterable,
     List,
     Mapping,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -72,10 +73,18 @@ logger = logging.getLogger(__name__)
 
 _T = TypeVar("_T")
 
-
 MAX_STATE_DELTA_HOPS = 100
 
 
+# Freeze so it's immutable and we can use it as a cache value
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Sentinel:
+    pass
+
+
+ROOM_UNKNOWN_SENTINEL = Sentinel()
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class EventMetadata:
     """Returned by `get_metadata_for_events`"""
@@ -300,51 +309,189 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @cached(max_entries=10000)
     async def get_room_type(self, room_id: str) -> Optional[str]:
-        """Get the room type for a given room. The server must be joined to the
-        given room.
-        """
-
-        row = await self.db_pool.simple_select_one(
-            table="room_stats_state",
-            keyvalues={"room_id": room_id},
-            retcols=("room_type",),
-            allow_none=True,
-            desc="get_room_type",
-        )
-
-        if row is not None:
-            return row[0]
-
-        # If we haven't updated `room_stats_state` with the room yet, query the
-        # create event directly.
-        create_event = await self.get_create_event_for_room(room_id)
-        room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
-        return room_type
+        raise NotImplementedError()
 
     @cachedList(cached_method_name="get_room_type", list_name="room_ids")
     async def bulk_get_room_type(
         self, room_ids: Set[str]
-    ) -> Mapping[str, Optional[str]]:
-        """Bulk fetch room types for the given rooms, the server must be in all
-        the rooms given.
+    ) -> Mapping[str, Union[Optional[str], Sentinel]]:
         """
+        Bulk fetch room types for the given rooms (via current state).
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="room_stats_state",
-            column="room_id",
-            iterable=room_ids,
-            retcols=("room_id", "room_type"),
-            desc="bulk_get_room_type",
+        Since this function is cached, any missing values would be cached as `None`. In
+        order to distinguish between an unencrypted room that has `None` encryption and
+        a room that is unknown to the server where we might want to omit the value
+        (which would make it cached as `None`), instead we use the sentinel value
+        `ROOM_UNKNOWN_SENTINEL`.
+
+        Returns:
+            A mapping from room ID to the room's type (`None` is a valid room type).
+            Rooms unknown to this server will return `ROOM_UNKNOWN_SENTINEL`.
+        """
+
+        def txn(
+            txn: LoggingTransaction,
+        ) -> MutableMapping[str, Union[Optional[str], Sentinel]]:
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "room_id", room_ids
+            )
+
+            # We can't rely on `room_stats_state.room_type` if the server has left the
+            # room because the `room_id` will still be in the table but everything will
+            # be set to `None` but `None` is a valid room type value. We join against
+            # the `room_stats_current` table which keeps track of the
+            # `current_state_events` count (and a proxy value `local_users_in_room`
+            # which can used to assume the server is participating in the room and has
+            # current state) to ensure that the data in `room_stats_state` is up-to-date
+            # with the current state.
+            #
+            # FIXME: Use `room_stats_current.current_state_events` instead of
+            # `room_stats_current.local_users_in_room` once
+            # https://github.com/element-hq/synapse/issues/17457 is fixed.
+            sql = f"""
+                SELECT room_id, room_type
+                FROM room_stats_state
+                INNER JOIN room_stats_current USING (room_id)
+                WHERE
+                    {clause}
+                    AND local_users_in_room > 0
+            """
+
+            txn.execute(sql, args)
+
+            room_id_to_type_map = {}
+            for row in txn:
+                room_id_to_type_map[row[0]] = row[1]
+
+            return room_id_to_type_map
+
+        results = await self.db_pool.runInteraction(
+            "bulk_get_room_type",
+            txn,
         )
 
         # If we haven't updated `room_stats_state` with the room yet, query the
         # create events directly. This should happen only rarely so we don't
         # mind if we do this in a loop.
-        results = dict(rows)
         for room_id in room_ids - results.keys():
-            create_event = await self.get_create_event_for_room(room_id)
-            room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
-            results[room_id] = room_type
+            try:
+                create_event = await self.get_create_event_for_room(room_id)
+                room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+                results[room_id] = room_type
+            except NotFoundError:
+                # We use the sentinel value to distinguish between `None` which is a
+                # valid room type and a room that is unknown to the server so the value
+                # is just unset.
+                results[room_id] = ROOM_UNKNOWN_SENTINEL
+
+        return results
+
+    @cached(max_entries=10000)
+    async def get_room_encryption(self, room_id: str) -> Optional[str]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_room_encryption", list_name="room_ids")
+    async def bulk_get_room_encryption(
+        self, room_ids: Set[str]
+    ) -> Mapping[str, Union[Optional[str], Sentinel]]:
+        """
+        Bulk fetch room encryption for the given rooms (via current state).
+
+        Since this function is cached, any missing values would be cached as `None`. In
+        order to distinguish between an unencrypted room that has `None` encryption and
+        a room that is unknown to the server where we might want to omit the value
+        (which would make it cached as `None`), instead we use the sentinel value
+        `ROOM_UNKNOWN_SENTINEL`.
+
+        Returns:
+            A mapping from room ID to the room's encryption algorithm if the room is
+            encrypted, otherwise `None`. Rooms unknown to this server will return
+            `ROOM_UNKNOWN_SENTINEL`.
+        """
+
+        def txn(
+            txn: LoggingTransaction,
+        ) -> MutableMapping[str, Union[Optional[str], Sentinel]]:
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "room_id", room_ids
+            )
+
+            # We can't rely on `room_stats_state.encryption` if the server has left the
+            # room because the `room_id` will still be in the table but everything will
+            # be set to `None` but `None` is a valid encryption value. We join against
+            # the `room_stats_current` table which keeps track of the
+            # `current_state_events` count (and a proxy value `local_users_in_room`
+            # which can used to assume the server is participating in the room and has
+            # current state) to ensure that the data in `room_stats_state` is up-to-date
+            # with the current state.
+            #
+            # FIXME: Use `room_stats_current.current_state_events` instead of
+            # `room_stats_current.local_users_in_room` once
+            # https://github.com/element-hq/synapse/issues/17457 is fixed.
+            sql = f"""
+                SELECT room_id, encryption
+                FROM room_stats_state
+                INNER JOIN room_stats_current USING (room_id)
+                WHERE
+                    {clause}
+                    AND local_users_in_room > 0
+            """
+
+            txn.execute(sql, args)
+
+            room_id_to_encryption_map = {}
+            for row in txn:
+                room_id_to_encryption_map[row[0]] = row[1]
+
+            return room_id_to_encryption_map
+
+        results = await self.db_pool.runInteraction(
+            "bulk_get_room_encryption",
+            txn,
+        )
+
+        # If we haven't updated `room_stats_state` with the room yet, query the state
+        # directly. This should happen only rarely so we don't mind if we do this in a
+        # loop.
+        encryption_event_ids: List[str] = []
+        for room_id in room_ids - results.keys():
+            state_map = await self.get_partial_filtered_current_state_ids(
+                room_id,
+                state_filter=StateFilter.from_types(
+                    [
+                        (EventTypes.Create, ""),
+                        (EventTypes.RoomEncryption, ""),
+                    ]
+                ),
+            )
+            # We can use the create event as a canary to tell whether the server has
+            # seen the room before
+            create_event_id = state_map.get((EventTypes.Create, ""))
+            encryption_event_id = state_map.get((EventTypes.RoomEncryption, ""))
+
+            if create_event_id is None:
+                # We use the sentinel value to distinguish between `None` which is a
+                # valid room type and a room that is unknown to the server so the value
+                # is just unset.
+                results[room_id] = ROOM_UNKNOWN_SENTINEL
+                continue
+
+            if encryption_event_id is None:
+                results[room_id] = None
+            else:
+                encryption_event_ids.append(encryption_event_id)
+
+        encryption_event_map = await self.get_events(encryption_event_ids)
+
+        for encryption_event_id in encryption_event_ids:
+            encryption_event = encryption_event_map.get(encryption_event_id)
+            # If the curent state says there is an encryption event, we should have it
+            # in the database.
+            assert encryption_event is not None
+
+            results[encryption_event.room_id] = encryption_event.content.get(
+                EventContentFields.ENCRYPTION_ALGORITHM
+            )
 
         return results