diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 26b8e1a172..63624f3e8f 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -268,13 +268,23 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
if data.type == EventTypes.Member:
- self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined]
+ self._attempt_to_invalidate_cache(
+ "get_rooms_for_user", (data.state_key,)
+ )
+ elif data.type == EventTypes.RoomEncryption:
+ self._attempt_to_invalidate_cache(
+ "get_room_encryption", (data.room_id,)
+ )
+ elif data.type == EventTypes.Create:
+ self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
elif row.type == EventsStreamAllStateRow.TypeId:
assert isinstance(data, EventsStreamAllStateRow)
# Similar to the above, but the entire caches are invalidated. This is
# unfortunate for the membership caches, but should recover quickly.
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
- self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined]
+ self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+ self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
+ self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,))
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
@@ -345,6 +355,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache(
"get_forgotten_rooms_for_user", (state_key,)
)
+ elif etype == EventTypes.Create:
+ self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+ elif etype == EventTypes.RoomEncryption:
+ self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
if relates_to:
self._attempt_to_invalidate_cache(
@@ -405,6 +419,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_thread_summary", None)
self._attempt_to_invalidate_cache("get_thread_participated", None)
self._attempt_to_invalidate_cache("get_threads", (room_id,))
+ self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+ self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
self._attempt_to_invalidate_cache("_get_state_group_for_event", None)
@@ -457,6 +473,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
self._attempt_to_invalidate_cache("get_room_version_id", (room_id,))
+ self._attempt_to_invalidate_cache("get_room_type", (room_id,))
+ self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
# And delete state caches.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5188b2f7a4..62bc4600fb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -30,6 +30,7 @@ from typing import (
Iterable,
List,
Mapping,
+ MutableMapping,
Optional,
Set,
Tuple,
@@ -72,10 +73,18 @@ logger = logging.getLogger(__name__)
_T = TypeVar("_T")
-
MAX_STATE_DELTA_HOPS = 100
+# Freeze so it's immutable and we can use it as a cache value
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Sentinel:
+ pass
+
+
+ROOM_UNKNOWN_SENTINEL = Sentinel()
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventMetadata:
"""Returned by `get_metadata_for_events`"""
@@ -300,51 +309,189 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
@cached(max_entries=10000)
async def get_room_type(self, room_id: str) -> Optional[str]:
- """Get the room type for a given room. The server must be joined to the
- given room.
- """
-
- row = await self.db_pool.simple_select_one(
- table="room_stats_state",
- keyvalues={"room_id": room_id},
- retcols=("room_type",),
- allow_none=True,
- desc="get_room_type",
- )
-
- if row is not None:
- return row[0]
-
- # If we haven't updated `room_stats_state` with the room yet, query the
- # create event directly.
- create_event = await self.get_create_event_for_room(room_id)
- room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
- return room_type
+ raise NotImplementedError()
@cachedList(cached_method_name="get_room_type", list_name="room_ids")
async def bulk_get_room_type(
self, room_ids: Set[str]
- ) -> Mapping[str, Optional[str]]:
- """Bulk fetch room types for the given rooms, the server must be in all
- the rooms given.
+ ) -> Mapping[str, Union[Optional[str], Sentinel]]:
"""
+ Bulk fetch room types for the given rooms (via current state).
- rows = await self.db_pool.simple_select_many_batch(
- table="room_stats_state",
- column="room_id",
- iterable=room_ids,
- retcols=("room_id", "room_type"),
- desc="bulk_get_room_type",
+ Since this function is cached, any missing values would be cached as `None`. In
+ order to distinguish between a room whose type is `None` (a valid room type) and
+ a room that is unknown to the server where we might want to omit the value
+ (which would make it cached as `None`), instead we use the sentinel value
+ `ROOM_UNKNOWN_SENTINEL`.
+
+ Returns:
+ A mapping from room ID to the room's type (`None` is a valid room type).
+ Rooms unknown to this server will return `ROOM_UNKNOWN_SENTINEL`.
+ """
+
+ def txn(
+ txn: LoggingTransaction,
+ ) -> MutableMapping[str, Union[Optional[str], Sentinel]]:
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "room_id", room_ids
+ )
+
+ # We can't rely on `room_stats_state.room_type` if the server has left the
+ # room because the `room_id` will still be in the table but everything will
+ # be set to `None` but `None` is a valid room type value. We join against
+ # the `room_stats_current` table which keeps track of the
+ # `current_state_events` count (and a proxy value `local_users_in_room`
+ # which can be used to assume the server is participating in the room and
+ # has current state) to ensure that the data in `room_stats_state` is
+ # up-to-date with the current state.
+ #
+ # FIXME: Use `room_stats_current.current_state_events` instead of
+ # `room_stats_current.local_users_in_room` once
+ # https://github.com/element-hq/synapse/issues/17457 is fixed.
+ sql = f"""
+ SELECT room_id, room_type
+ FROM room_stats_state
+ INNER JOIN room_stats_current USING (room_id)
+ WHERE
+ {clause}
+ AND local_users_in_room > 0
+ """
+
+ txn.execute(sql, args)
+
+ room_id_to_type_map = {}
+ for row in txn:
+ room_id_to_type_map[row[0]] = row[1]
+
+ return room_id_to_type_map
+
+ results = await self.db_pool.runInteraction(
+ "bulk_get_room_type",
+ txn,
)
# If we haven't updated `room_stats_state` with the room yet, query the
# create events directly. This should happen only rarely so we don't
# mind if we do this in a loop.
- results = dict(rows)
for room_id in room_ids - results.keys():
- create_event = await self.get_create_event_for_room(room_id)
- room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
- results[room_id] = room_type
+ try:
+ create_event = await self.get_create_event_for_room(room_id)
+ room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+ results[room_id] = room_type
+ except NotFoundError:
+ # We use the sentinel value to distinguish between `None` which is a
+ # valid room type and a room that is unknown to the server so the value
+ # is just unset.
+ results[room_id] = ROOM_UNKNOWN_SENTINEL
+
+ return results
+
+ @cached(max_entries=10000)
+ async def get_room_encryption(self, room_id: str) -> Optional[str]:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="get_room_encryption", list_name="room_ids")
+ async def bulk_get_room_encryption(
+ self, room_ids: Set[str]
+ ) -> Mapping[str, Union[Optional[str], Sentinel]]:
+ """
+ Bulk fetch room encryption for the given rooms (via current state).
+
+ Since this function is cached, any missing values would be cached as `None`. In
+ order to distinguish between an unencrypted room that has `None` encryption and
+ a room that is unknown to the server where we might want to omit the value
+ (which would make it cached as `None`), instead we use the sentinel value
+ `ROOM_UNKNOWN_SENTINEL`.
+
+ Returns:
+ A mapping from room ID to the room's encryption algorithm if the room is
+ encrypted, otherwise `None`. Rooms unknown to this server will return
+ `ROOM_UNKNOWN_SENTINEL`.
+ """
+
+ def txn(
+ txn: LoggingTransaction,
+ ) -> MutableMapping[str, Union[Optional[str], Sentinel]]:
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "room_id", room_ids
+ )
+
+ # We can't rely on `room_stats_state.encryption` if the server has left the
+ # room because the `room_id` will still be in the table but everything will
+ # be set to `None` but `None` is a valid encryption value. We join against
+ # the `room_stats_current` table which keeps track of the
+ # `current_state_events` count (and a proxy value `local_users_in_room`
+ # which can be used to assume the server is participating in the room and
+ # has current state) to ensure that the data in `room_stats_state` is
+ # up-to-date with the current state.
+ #
+ # FIXME: Use `room_stats_current.current_state_events` instead of
+ # `room_stats_current.local_users_in_room` once
+ # https://github.com/element-hq/synapse/issues/17457 is fixed.
+ sql = f"""
+ SELECT room_id, encryption
+ FROM room_stats_state
+ INNER JOIN room_stats_current USING (room_id)
+ WHERE
+ {clause}
+ AND local_users_in_room > 0
+ """
+
+ txn.execute(sql, args)
+
+ room_id_to_encryption_map = {}
+ for row in txn:
+ room_id_to_encryption_map[row[0]] = row[1]
+
+ return room_id_to_encryption_map
+
+ results = await self.db_pool.runInteraction(
+ "bulk_get_room_encryption",
+ txn,
+ )
+
+ # If we haven't updated `room_stats_state` with the room yet, query the state
+ # directly. This should happen only rarely so we don't mind if we do this in a
+ # loop.
+ encryption_event_ids: List[str] = []
+ for room_id in room_ids - results.keys():
+ state_map = await self.get_partial_filtered_current_state_ids(
+ room_id,
+ state_filter=StateFilter.from_types(
+ [
+ (EventTypes.Create, ""),
+ (EventTypes.RoomEncryption, ""),
+ ]
+ ),
+ )
+ # We can use the create event as a canary to tell whether the server has
+ # seen the room before
+ create_event_id = state_map.get((EventTypes.Create, ""))
+ encryption_event_id = state_map.get((EventTypes.RoomEncryption, ""))
+
+ if create_event_id is None:
+ # We use the sentinel value to distinguish between `None` which is a
+ # valid encryption value and a room that is unknown to the server so
+ # the value is just unset.
+ results[room_id] = ROOM_UNKNOWN_SENTINEL
+ continue
+
+ if encryption_event_id is None:
+ results[room_id] = None
+ else:
+ encryption_event_ids.append(encryption_event_id)
+
+ encryption_event_map = await self.get_events(encryption_event_ids)
+
+ for encryption_event_id in encryption_event_ids:
+ encryption_event = encryption_event_map.get(encryption_event_id)
+ # If the current state says there is an encryption event, we should have it
+ # in the database.
+ assert encryption_event is not None
+
+ results[encryption_event.room_id] = encryption_event.content.get(
+ EventContentFields.ENCRYPTION_ALGORITHM
+ )
return results
|