diff --git a/changelog.d/16564.misc b/changelog.d/16564.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16564.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 2c2baeac67..d06f8e3296 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -283,7 +283,7 @@ class AdminHandler:
start, limit, user_id
)
for media in media_ids:
- writer.write_media_id(media["media_id"], media)
+ writer.write_media_id(media.media_id, attr.asdict(media))
logger.info(
"[%s] Written %d media_ids of %s",
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 36e2db8975..2947e154be 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.storage.databases.main.room import LargestRoomStats
from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
@@ -170,26 +171,24 @@ class RoomListHandler:
ignore_non_federatable=from_federation,
)
- def build_room_entry(room: JsonDict) -> JsonDict:
+ def build_room_entry(room: LargestRoomStats) -> JsonDict:
entry = {
- "room_id": room["room_id"],
- "name": room["name"],
- "topic": room["topic"],
- "canonical_alias": room["canonical_alias"],
- "num_joined_members": room["joined_members"],
- "avatar_url": room["avatar"],
- "world_readable": room["history_visibility"]
+ "room_id": room.room_id,
+ "name": room.name,
+ "topic": room.topic,
+ "canonical_alias": room.canonical_alias,
+ "num_joined_members": room.joined_members,
+ "avatar_url": room.avatar,
+ "world_readable": room.history_visibility
== HistoryVisibility.WORLD_READABLE,
- "guest_can_join": room["guest_access"] == "can_join",
- "join_rule": room["join_rules"],
- "room_type": room["room_type"],
+ "guest_can_join": room.guest_access == "can_join",
+ "join_rule": room.join_rules,
+ "room_type": room.room_type,
}
# Filter out Nones – rather omit the field altogether
return {k: v for k, v in entry.items() if v is not None}
- results = [build_room_entry(r) for r in results]
-
response: JsonDict = {}
num_results = len(results)
if limit is not None:
@@ -212,33 +211,33 @@ class RoomListHandler:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
- last_joined_members=initial_entry["num_joined_members"],
- last_room_id=initial_entry["room_id"],
+ last_joined_members=initial_entry.joined_members,
+ last_room_id=initial_entry.room_id,
direction_is_forward=False,
).to_token()
if more_to_come:
response["next_batch"] = RoomListNextBatch(
- last_joined_members=final_entry["num_joined_members"],
- last_room_id=final_entry["room_id"],
+ last_joined_members=final_entry.joined_members,
+ last_room_id=final_entry.room_id,
direction_is_forward=True,
).to_token()
else:
if has_batch_token:
response["next_batch"] = RoomListNextBatch(
- last_joined_members=final_entry["num_joined_members"],
- last_room_id=final_entry["room_id"],
+ last_joined_members=final_entry.joined_members,
+ last_room_id=final_entry.room_id,
direction_is_forward=True,
).to_token()
if more_to_come:
response["prev_batch"] = RoomListNextBatch(
- last_joined_members=initial_entry["num_joined_members"],
- last_room_id=initial_entry["room_id"],
+ last_joined_members=initial_entry.joined_members,
+ last_room_id=initial_entry.room_id,
direction_is_forward=False,
).to_token()
- response["chunk"] = results
+ response["chunk"] = [build_room_entry(r) for r in results]
response["total_room_count_estimate"] = await self.store.count_public_rooms(
network_tuple,
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index dd559b4c45..1dfb12e065 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -703,24 +703,24 @@ class RoomSummaryHandler:
# there should always be an entry
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
- entry = {
- "room_id": stats["room_id"],
- "name": stats["name"],
- "topic": stats["topic"],
- "canonical_alias": stats["canonical_alias"],
- "num_joined_members": stats["joined_members"],
- "avatar_url": stats["avatar"],
- "join_rule": stats["join_rules"],
+ entry: JsonDict = {
+ "room_id": stats.room_id,
+ "name": stats.name,
+ "topic": stats.topic,
+ "canonical_alias": stats.canonical_alias,
+ "num_joined_members": stats.joined_members,
+ "avatar_url": stats.avatar,
+ "join_rule": stats.join_rules,
"world_readable": (
- stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
+ stats.history_visibility == HistoryVisibility.WORLD_READABLE
),
- "guest_can_join": stats["guest_access"] == "can_join",
- "room_type": stats["room_type"],
+ "guest_can_join": stats.guest_access == "can_join",
+ "room_type": stats.room_type,
}
if self._msc3266_enabled:
- entry["im.nheko.summary.version"] = stats["version"]
- entry["im.nheko.summary.encryption"] = stats["encryption"]
+ entry["im.nheko.summary.version"] = stats.version
+ entry["im.nheko.summary.encryption"] = stats.encryption
# Federation requests need to provide additional information so the
# requested server is able to filter the response appropriately.
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index b7637dff0b..8cf5268854 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,6 +17,8 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
+import attr
+
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
@@ -418,7 +420,7 @@ class UserMediaRestServlet(RestServlet):
start, limit, user_id, order_by, direction
)
- ret = {"media": media, "total": total}
+ ret = {"media": [attr.asdict(m) for m in media], "total": total}
if (start + limit) < total:
ret["next_token"] = start + len(media)
@@ -477,7 +479,7 @@ class UserMediaRestServlet(RestServlet):
)
deleted_media, total = await self.media_repository.delete_local_media_ids(
- [row["media_id"] for row in media]
+ [m.media_id for m in media]
)
return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total}
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index ffce92d45e..f3e06d3da3 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -77,7 +77,18 @@ class ListRegistrationTokensRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
valid = parse_boolean(request, "valid")
token_list = await self.store.get_registration_tokens(valid)
- return HTTPStatus.OK, {"registration_tokens": token_list}
+ return HTTPStatus.OK, {
+ "registration_tokens": [
+ {
+ "token": t[0],
+ "uses_allowed": t[1],
+ "pending": t[2],
+ "completed": t[3],
+ "expiry_time": t[4],
+ }
+ for t in token_list
+ ]
+ }
class NewRegistrationTokenRestServlet(RestServlet):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 0659f22a89..23a034522c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -16,6 +16,8 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse
+import attr
+
from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.filtering import Filter
@@ -306,10 +308,13 @@ class RoomRestServlet(RestServlet):
raise NotFoundError("Room not found")
members = await self.store.get_users_in_room(room_id)
- ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
- ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
+ result = attr.asdict(ret)
+ result["joined_local_devices"] = await self.store.count_devices_by_users(
+ members
+ )
+ result["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
- return HTTPStatus.OK, ret
+ return HTTPStatus.OK, result
async def on_DELETE(
self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 7fe16130e7..73878dd99d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -18,6 +18,8 @@ import secrets
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+import attr
+
from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -161,11 +163,13 @@ class UsersRestServletV2(RestServlet):
)
# If support for MSC3866 is not enabled, don't show the approval flag.
+ filter = None
if not self._msc3866_enabled:
- for user in users:
- del user["approved"]
- ret = {"users": users, "total": total}
+ def _filter(a: attr.Attribute) -> bool:
+ return a.name != "approved"
+
+ ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total}
if (start + limit) < total:
ret["next_token"] = str(start + len(users))
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 12829d3d7d..7426dbcad6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -28,6 +28,7 @@ from typing import (
Sequence,
Tuple,
Type,
+ cast,
)
import attr
@@ -488,14 +489,14 @@ class BackgroundUpdater:
True if we have finished running all the background updates, otherwise False
"""
- def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
+ def get_background_updates_txn(txn: Cursor) -> List[Tuple[str, Optional[str]]]:
txn.execute(
"""
SELECT update_name, depends_on FROM background_updates
ORDER BY ordering, update_name
"""
)
- return self.db_pool.cursor_to_dict(txn)
+ return cast(List[Tuple[str, Optional[str]]], txn.fetchall())
if not self._current_background_update:
all_pending_updates = await self.db_pool.runInteraction(
@@ -507,14 +508,13 @@ class BackgroundUpdater:
return True
# find the first update which isn't dependent on another one in the queue.
- pending = {update["update_name"] for update in all_pending_updates}
- for upd in all_pending_updates:
- depends_on = upd["depends_on"]
+ pending = {update_name for update_name, depends_on in all_pending_updates}
+ for update_name, depends_on in all_pending_updates:
if not depends_on or depends_on not in pending:
break
logger.info(
"Not starting on bg update %s until %s is done",
- upd["update_name"],
+ update_name,
depends_on,
)
else:
@@ -524,7 +524,7 @@ class BackgroundUpdater:
"another: dependency cycle?"
)
- self._current_background_update = upd["update_name"]
+ self._current_background_update = update_name
# We have a background update to run, otherwise we would have returned
# early.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a4e7048368..6d54bb0eb2 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -18,7 +18,6 @@ import logging
import time
import types
from collections import defaultdict
-from sys import intern
from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
@@ -1042,20 +1041,6 @@ class DatabasePool:
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
)
- @staticmethod
- def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
- """Converts a SQL cursor into an list of dicts.
-
- Args:
- cursor: The DBAPI cursor which has executed a query.
- Returns:
- A list of dicts where the key is the column header.
- """
- assert cursor.description is not None, "cursor.description was None"
- col_headers = [intern(str(column[0])) for column in cursor.description]
- results = [dict(zip(col_headers, row)) for row in cursor]
- return results
-
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 840d725114..89f4077351 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,8 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
+import attr
+
from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
from synapse.storage._base import make_in_list_sql_clause
@@ -28,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import get_domain_from_id
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -82,6 +84,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPaginateResponse:
+ """This is very similar to UserInfo, but not quite the same."""
+
+ name: str
+ user_type: Optional[str]
+ is_guest: bool
+ admin: bool
+ deactivated: bool
+ shadow_banned: bool
+ displayname: Optional[str]
+ avatar_url: Optional[str]
+ creation_ts: Optional[int]
+ approved: bool
+ erased: bool
+ last_seen_ts: int
+ locked: bool
+
+
class DataStore(
EventsBackgroundUpdatesStore,
ExperimentalFeaturesStore,
@@ -156,7 +177,7 @@ class DataStore(
approved: bool = True,
not_user_types: Optional[List[str]] = None,
locked: bool = False,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[UserPaginateResponse], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
@@ -182,7 +203,7 @@ class DataStore(
def get_users_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[UserPaginateResponse], int]:
filters = []
args: list = []
@@ -282,13 +303,24 @@ class DataStore(
"""
args += [limit, start]
txn.execute(sql, args)
- users = self.db_pool.cursor_to_dict(txn)
-
- # some of those boolean values are returned as integers when we're on SQLite
- columns_to_boolify = ["erased"]
- for user in users:
- for column in columns_to_boolify:
- user[column] = bool(user[column])
+ users = [
+ UserPaginateResponse(
+ name=row[0],
+ user_type=row[1],
+ is_guest=bool(row[2]),
+ admin=bool(row[3]),
+ deactivated=bool(row[4]),
+ shadow_banned=bool(row[5]),
+ displayname=row[6],
+ avatar_url=row[7],
+ creation_ts=row[8],
+ approved=bool(row[9]),
+ erased=bool(row[10]),
+ last_seen_ts=row[11],
+ locked=bool(row[12]),
+ )
+ for row in txn
+ ]
return users, count
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 49edbb9e06..b0811a4cf1 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1620,7 +1620,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
#
# For each duplicate, we delete all the existing rows and put one back.
- KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
last_row = progress.get(
"last_row",
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
@@ -1628,44 +1627,62 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
def _txn(txn: LoggingTransaction) -> int:
clause, args = make_tuple_comparison_clause(
- [(x, last_row[x]) for x in KEY_COLS]
+ [
+ ("stream_id", last_row["stream_id"]),
+ ("destination", last_row["destination"]),
+ ("user_id", last_row["user_id"]),
+ ("device_id", last_row["device_id"]),
+ ]
)
- sql = """
+ sql = f"""
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
FROM device_lists_outbound_pokes
- WHERE %s
- GROUP BY %s
+ WHERE {clause}
+ GROUP BY stream_id, destination, user_id, device_id
HAVING count(*) > 1
- ORDER BY %s
+ ORDER BY stream_id, destination, user_id, device_id
LIMIT ?
- """ % (
- clause, # WHERE
- ",".join(KEY_COLS), # GROUP BY
- ",".join(KEY_COLS), # ORDER BY
- )
+ """
txn.execute(sql, args + [batch_size])
- rows = self.db_pool.cursor_to_dict(txn)
+ rows = txn.fetchall()
- row = None
- for row in rows:
+ stream_id, destination, user_id, device_id = None, None, None, None
+ for stream_id, destination, user_id, device_id, _ in rows:
self.db_pool.simple_delete_txn(
txn,
"device_lists_outbound_pokes",
- {x: row[x] for x in KEY_COLS},
+ {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ },
)
- row["sent"] = False
self.db_pool.simple_insert_txn(
txn,
"device_lists_outbound_pokes",
- row,
+ {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ "sent": False,
+ },
)
- if row:
+ if rows:
self.db_pool.updates._background_update_progress_txn(
txn,
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
- {"last_row": row},
+ {
+ "last_row": {
+ "stream_id": stream_id,
+ "destination": destination,
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ },
)
return len(rows)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index aeb3db596c..c8d7c9fd32 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,6 +26,8 @@ from typing import (
cast,
)
+import attr
+
from synapse.api.constants import Direction
from synapse.logging.opentracing import trace
from synapse.media._base import ThumbnailInfo
@@ -45,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LocalMedia:
+ media_id: str
+ media_type: str
+ media_length: int
+ upload_name: str
+ created_ts: int
+ last_access_ts: int
+ quarantined_by: Optional[str]
+ safe_from_quarantine: bool
+
+
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
@@ -180,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[LocalMedia], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@@ -197,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[LocalMedia], int]:
# Set ordering
order_by_column = MediaSortOrder(order_by).value
@@ -217,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = """
SELECT
- "media_id",
- "media_type",
- "media_length",
- "upload_name",
- "created_ts",
- "last_access_ts",
- "quarantined_by",
- "safe_from_quarantine"
+ media_id,
+ media_type,
+ media_length,
+ upload_name,
+ created_ts,
+ last_access_ts,
+ quarantined_by,
+ safe_from_quarantine
FROM local_media_repository
WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC
@@ -236,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
args += [limit, start]
txn.execute(sql, args)
- media = self.db_pool.cursor_to_dict(txn)
+ media = [
+ LocalMedia(
+ media_id=row[0],
+ media_type=row[1],
+ media_length=row[2],
+ upload_name=row[3],
+ created_ts=row[4],
+ last_access_ts=row[5],
+ quarantined_by=row[6],
+ safe_from_quarantine=bool(row[7]),
+ )
+ for row in txn
+ ]
return media, count
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e09ab21593..933d76e905 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1517,7 +1517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_registration_tokens(
self, valid: Optional[bool] = None
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
"""List all registration tokens. Used by the admin API.
Args:
@@ -1526,34 +1526,48 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Default is None: return all tokens regardless of validity.
Returns:
- A list of dicts, each containing details of a token.
+ A list of tuples containing:
+ * The token
+ * The number of users allowed (or None)
+ * Whether it is pending
+ * Whether it has been completed
+ * An expiry time (or None if no expiry)
"""
def select_registration_tokens_txn(
txn: LoggingTransaction, now: int, valid: Optional[bool]
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, Optional[int], int, int, Optional[int]]]:
if valid is None:
# Return all tokens regardless of validity
- txn.execute("SELECT * FROM registration_tokens")
+ txn.execute(
+ """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ """
+ )
elif valid:
# Select valid tokens only
- sql = (
- "SELECT * FROM registration_tokens WHERE "
- "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
- "AND (expiry_time > ? OR expiry_time IS NULL)"
- )
+ sql = """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)
+ AND (expiry_time > ? OR expiry_time IS NULL)
+ """
txn.execute(sql, [now])
else:
# Select invalid tokens only
- sql = (
- "SELECT * FROM registration_tokens WHERE "
- "uses_allowed <= pending + completed OR expiry_time <= ?"
- )
+ sql = """
+ SELECT token, uses_allowed, pending, completed, expiry_time
+ FROM registration_tokens
+ WHERE uses_allowed <= pending + completed OR expiry_time <= ?
+ """
txn.execute(sql, [now])
- return self.db_pool.cursor_to_dict(txn)
+ return cast(
+ List[Tuple[str, Optional[int], int, int, Optional[int]]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"select_registration_tokens",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3e8fcf1975..6d4b9891e7 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -78,6 +78,31 @@ class RatelimitOverride:
burst_count: int
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LargestRoomStats:
+ room_id: str
+ name: Optional[str]
+ canonical_alias: Optional[str]
+ joined_members: int
+ join_rules: Optional[str]
+ guest_access: Optional[str]
+ history_visibility: Optional[str]
+ state_events: int
+ avatar: Optional[str]
+ topic: Optional[str]
+ room_type: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomStats(LargestRoomStats):
+ joined_local_members: int
+ version: Optional[str]
+ creator: Optional[str]
+ encryption: Optional[str]
+ federatable: bool
+ public: bool
+
+
class RoomSortOrder(Enum):
"""
Enum to define the sorting method used when returning rooms with get_rooms_paginate
@@ -204,7 +229,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
allow_none=True,
)
- async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
+ async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]:
"""Retrieve room with statistics.
Args:
@@ -215,7 +240,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_room_with_stats_txn(
txn: LoggingTransaction, room_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[RoomStats]:
sql = """
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -229,15 +254,28 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
WHERE room_id = ?
"""
txn.execute(sql, [room_id])
- # Catch error if sql returns empty result to return "None" instead of an error
- try:
- res = self.db_pool.cursor_to_dict(txn)[0]
- except IndexError:
+ row = txn.fetchone()
+ if not row:
return None
-
- res["federatable"] = bool(res["federatable"])
- res["public"] = bool(res["public"])
- return res
+ return RoomStats(
+ room_id=row[0],
+ name=row[1],
+ canonical_alias=row[2],
+ joined_members=row[3],
+ joined_local_members=row[4],
+ version=row[5],
+ creator=row[6],
+ encryption=row[7],
+ federatable=bool(row[8]),
+ public=bool(row[9]),
+ join_rules=row[10],
+ guest_access=row[11],
+ history_visibility=row[12],
+ state_events=row[13],
+ avatar=row[14],
+ topic=row[15],
+ room_type=row[16],
+ )
return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
@@ -368,7 +406,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
- ) -> List[Dict[str, Any]]:
+ ) -> List[LargestRoomStats]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
@@ -505,20 +543,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _get_largest_public_rooms_txn(
txn: LoggingTransaction,
- ) -> List[Dict[str, Any]]:
+ ) -> List[LargestRoomStats]:
txn.execute(sql, query_args)
- results = self.db_pool.cursor_to_dict(txn)
+ results = [
+ LargestRoomStats(
+ room_id=r[0],
+ name=r[1],
+ canonical_alias=r[3],
+ joined_members=r[4],
+ join_rules=r[8],
+ guest_access=r[7],
+ history_visibility=r[6],
+ state_events=0,
+ avatar=r[5],
+ topic=r[2],
+ room_type=r[9],
+ )
+ for r in txn
+ ]
if not forwards:
results.reverse()
return results
- ret_val = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
- return ret_val
@cached(max_entries=10000)
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e9fbf32c7c..032b89d684 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -342,10 +342,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly not federated.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
- self.assertFalse(room["federatable"])
- self.assertFalse(room["public"])
- self.assertEqual(room["join_rules"], "public")
- self.assertIsNone(room["guest_access"])
+ self.assertFalse(room.federatable)
+ self.assertFalse(room.public)
+ self.assertEqual(room.join_rules, "public")
+ self.assertIsNone(room.guest_access)
# The user should be in the room.
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -372,7 +372,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a public room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
- self.assertEqual(room["join_rules"], "public")
+ self.assertEqual(room.join_rules, "public")
# Both users should be in the room.
rooms = self.get_success(self.store.get_rooms_for_user(inviter))
@@ -411,9 +411,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
- self.assertFalse(room["public"])
- self.assertEqual(room["join_rules"], "invite")
- self.assertEqual(room["guest_access"], "can_join")
+ self.assertFalse(room.public)
+ self.assertEqual(room.join_rules, "invite")
+ self.assertEqual(room.guest_access, "can_join")
# Both users should be in the room.
rooms = self.get_success(self.store.get_rooms_for_user(inviter))
@@ -455,9 +455,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
- self.assertFalse(room["public"])
- self.assertEqual(room["join_rules"], "invite")
- self.assertEqual(room["guest_access"], "can_join")
+ self.assertFalse(room.public)
+ self.assertEqual(room.join_rules, "invite")
+ self.assertEqual(room.guest_access, "can_join")
# Both users should be in the room.
rooms = self.get_success(self.store.get_rooms_for_user(inviter))
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index b8823d6993..01c0e5e671 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -39,11 +39,11 @@ class DataStoreTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(1, total)
- self.assertEqual(self.displayname, users.pop()["displayname"])
+ self.assertEqual(self.displayname, users.pop().displayname)
users, total = self.get_success(
self.store.get_users_paginate(0, 10, name="BC", guests=False)
)
self.assertEqual(1, total)
- self.assertEqual(self.displayname, users.pop()["displayname"])
+ self.assertEqual(self.displayname, users.pop().displayname)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1e27f2c275..ce34195a25 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -59,14 +59,9 @@ class RoomStoreTestCase(HomeserverTestCase):
def test_get_room_with_stats(self) -> None:
res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
assert res is not None
- self.assertLessEqual(
- {
- "room_id": self.room.to_string(),
- "creator": self.u_creator.to_string(),
- "public": True,
- }.items(),
- res.items(),
- )
+ self.assertEqual(res.room_id, self.room.to_string())
+ self.assertEqual(res.creator, self.u_creator.to_string())
+ self.assertTrue(res.public)
def test_get_room_with_stats_unknown_room(self) -> None:
self.assertIsNone(
|