diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index afb880532e..ef26d5d9d3 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
+ async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]:
"""Retrieve a room.
Args:
room_id: The ID of the room to retrieve.
Returns:
- A dict containing the room information, or None if the room is unknown.
+ A tuple containing the room information:
+ * True if the room is public
+ * True if the room has an auth chain index
+
+ or None if the room is unknown.
"""
- return await self.db_pool.simple_select_one(
- table="rooms",
- keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
- desc="get_room",
- allow_none=True,
+ row = cast(
+ Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]],
+ await self.db_pool.simple_select_one(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("is_public", "has_auth_chain_index"),
+ desc="get_room",
+ allow_none=True,
+ ),
)
+ if row is None:
+ return row
+ return bool(row[0]), bool(row[1])
async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics.
@@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
if row:
- return RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- )
+ return RatelimitOverride(messages_per_second=row[0], burst_count=row[1])
else:
return None
@@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
join.
"""
- result = await self.db_pool.simple_select_one(
- table="partial_state_rooms",
- keyvalues={"room_id": room_id},
- retcols=("join_event_id", "device_lists_stream_id"),
- desc="get_join_event_id_for_partial_state",
+ return cast(
+ Tuple[str, int],
+ await self.db_pool.simple_select_one(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("join_event_id", "device_lists_stream_id"),
+ desc="get_join_event_id_for_partial_state",
+ ),
)
- return result["join_event_id"], result["device_lists_stream_id"]
def get_un_partial_stated_rooms_token(self, instance_name: str) -> int:
return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer(
|