diff --git a/changelog.d/16505.misc b/changelog.d/16505.misc
new file mode 100644
index 0000000000..bd7cdd42af
--- /dev/null
+++ b/changelog.d/16505.misc
@@ -0,0 +1 @@
+Reduce memory allocations.
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 6a8f8f2fd1..370f4041fb 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -103,10 +103,10 @@ class DeactivateAccountHandler:
# Attempt to unbind any known bound threepids to this account from identity
# server(s).
bound_threepids = await self.store.user_get_bound_threepids(user_id)
- for threepid in bound_threepids:
+ for medium, address in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
- user_id, threepid["medium"], threepid["address"], id_server
+ user_id, medium, address, id_server
)
except Exception:
# Do we want this to be a fatal error or should we carry on?
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e9a544e754..62f2454f5d 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -1206,10 +1206,7 @@ class SsoHandler:
# We have no guarantee that all the devices of that session are for the same
# `user_id`. Hence, we have to iterate over the list of devices and log them out
# one by one.
- for device in devices:
- user_id = device["user_id"]
- device_id = device["device_id"]
-
+ for user_id, device_id in devices:
# If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 81f661160c..774d5c12f0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -606,13 +606,16 @@ class DatabasePool:
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = await self.simple_select_list(
- "background_updates",
- keyvalues=None,
- retcols=["update_name"],
- desc="check_background_updates",
+ updates = cast(
+ List[Tuple[str]],
+ await self.simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ ),
)
- background_update_names = [x["update_name"] for x in updates]
+ background_update_names = [x[0] for x in updates]
for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in background_update_names:
@@ -1804,9 +1807,9 @@ class DatabasePool:
keyvalues: Optional[Dict[str, Any]],
retcols: Collection[str],
desc: str = "simple_select_list",
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows, returning the result as a list of tuples.
Args:
table: the table name
@@ -1817,8 +1820,7 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
Returns:
- A list of dictionaries, one per result row, each a mapping between the
- column names from `retcols` and that column's value for the row.
+ A list of tuples, one per result row, each the retcolumn's value for the row.
"""
return await self.runInteraction(
desc,
@@ -1836,9 +1838,9 @@ class DatabasePool:
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
+ more rows, returning the result as a list of tuples.
Args:
txn: Transaction object
@@ -1849,8 +1851,7 @@ class DatabasePool:
retcols: the names of the columns to return
Returns:
- A list of dictionaries, one per result row, each a mapping between the
- column names from `retcols` and that column's value for the row.
+ A list of tuples, one per result row, each the retcolumn's value for the row.
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1863,7 +1864,7 @@ class DatabasePool:
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)
- return cls.cursor_to_dict(txn)
+ return txn.fetchall()
async def simple_select_many_batch(
self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 84ef8136c2..d7482a1f4e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -286,16 +286,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_room_txn(
txn: LoggingTransaction,
- ) -> Dict[str, JsonDict]:
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "room_account_data",
- {"user_id": user_id, "room_id": room_id},
- ["account_data_type", "content"],
+ ) -> Dict[str, JsonMapping]:
+ rows = cast(
+ List[Tuple[str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=["account_data_type", "content"],
+ ),
)
return {
- row["account_data_type"]: db_to_json(row["content"]) for row in rows
+ account_data_type: db_to_json(content)
+ for account_data_type, content in rows
}
return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 073a99cd84..fa7d1c469a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -197,16 +197,21 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A list of ApplicationServices, which may be empty.
"""
- results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state.value}, ["as_id"]
+ results = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="application_services_state",
+ keyvalues={"state": state.value},
+ retcols=("as_id",),
+ ),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []
- for res in results:
+ for (as_id,) in results:
for service in as_list:
- if service.id == res["as_id"]:
+ if service.id == as_id:
services.append(service)
return services
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8be1511859..c006129625 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -508,21 +508,24 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
if device_id is not None:
keyvalues["device_id"] = device_id
- res = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues=keyvalues,
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ res = cast(
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
)
return {
- (d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
- user_id=d["user_id"],
- device_id=d["device_id"],
- ip=d["ip"],
- user_agent=d["user_agent"],
- last_seen=d["last_seen"],
+ (user_id, device_id): DeviceLastConnectionInfo(
+ user_id=user_id,
+ device_id=device_id,
+ ip=ip,
+ user_agent=user_agent,
+ last_seen=last_seen,
)
- for d in res
+ for user_id, ip, user_agent, device_id, last_seen in res
}
async def _get_user_ip_and_agents_from_database(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fc23d18eba..0b75f6763a 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
allow_none=True,
)
- async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
+ async def get_devices_by_user(
+ self, user_id: str
+ ) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
@@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
- and "display_name" for each device.
+ and "display_name" for each device. Display name may be null.
"""
- devices = await self.db_pool.simple_select_list(
- table="devices",
- keyvalues={"user_id": user_id, "hidden": False},
- retcols=("user_id", "device_id", "display_name"),
- desc="get_devices_by_user",
+ devices = cast(
+ List[Tuple[str, str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={"user_id": user_id, "hidden": False},
+ retcols=("user_id", "device_id", "display_name"),
+ desc="get_devices_by_user",
+ ),
)
- return {d["device_id"]: d for d in devices}
+ return {
+ d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
+ for d in devices
+ }
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
- ) -> List[Dict[str, Any]]:
+ ) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
@@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
- return await self.db_pool.simple_select_list(
- table="device_auth_providers",
- keyvalues={
- "auth_provider_id": auth_provider_id,
- "auth_provider_session_id": auth_provider_session_id,
- },
- retcols=("user_id", "device_id"),
- desc="get_devices_by_auth_provider_session_id",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_auth_providers",
+ keyvalues={
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
+ },
+ retcols=("user_id", "device_id"),
+ desc="get_devices_by_auth_provider_session_id",
+ ),
)
@trace
@@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
- devices = await self.db_pool.simple_select_list(
- table="device_lists_remote_cache",
- keyvalues={"user_id": user_id},
- retcols=("device_id", "content"),
- desc="get_cached_devices_for_user",
+ devices = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("device_id", "content"),
+ desc="get_cached_devices_for_user",
+ ),
)
- return {
- device["device_id"]: db_to_json(device["content"]) for device in devices
- }
+ return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
@@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- row_tuples = cast(
+ rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
@@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)
-
- return {row[0] for row in row_tuples}
else:
rows = cast(
- List[Dict[str, str]],
+ List[Tuple[str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
@@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
),
)
- return {row["user_id"] for row in rows}
+ return {row[0] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index aac4cfb054..ad904a26a6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
@@ -274,32 +274,41 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = await self.db_pool.simple_select_list(
- table="e2e_room_keys",
- keyvalues=keyvalues,
- retcols=(
- "user_id",
- "room_id",
- "session_id",
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
+ rows = cast(
+ List[Tuple[str, str, int, int, int, str]],
+ await self.db_pool.simple_select_list(
+ table="e2e_room_keys",
+ keyvalues=keyvalues,
+ retcols=(
+ "room_id",
+ "session_id",
+ "first_message_index",
+ "forwarded_count",
+ "is_verified",
+ "session_data",
+ ),
+ desc="get_e2e_room_keys",
),
- desc="get_e2e_room_keys",
)
sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
- for row in rows:
- room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
- room_entry["sessions"][row["session_id"]] = {
- "first_message_index": row["first_message_index"],
- "forwarded_count": row["forwarded_count"],
+ for (
+ room_id,
+ session_id,
+ first_message_index,
+ forwarded_count,
+ is_verified,
+ session_data,
+ ) in rows:
+ room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
+ room_entry["sessions"][session_id] = {
+ "first_message_index": first_message_index,
+ "forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
- "is_verified": bool(row["is_verified"]),
- "session_data": db_to_json(row["session_data"]),
+ "is_verified": bool(is_verified),
+ "session_data": db_to_json(session_data),
}
return sessions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4f80ce75cc..f1b0991503 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1898,21 +1898,23 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# keeping only the forward extremities (i.e. the events not referenced
# by other events in the queue). We do this so that we can always
# backpaginate in all the events we have dropped.
- rows = await self.db_pool.simple_select_list(
- table="federation_inbound_events_staging",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "event_json"),
- desc="prune_staged_events_in_room_fetch",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "event_json"),
+ desc="prune_staged_events_in_room_fetch",
+ ),
)
# Find the set of events referenced by those in the queue, as well as
# collecting all the event IDs in the queue.
referenced_events: Set[str] = set()
seen_events: Set[str] = set()
- for row in rows:
- event_id = row["event_id"]
+ for event_id, event_json in rows:
seen_events.add(event_id)
- event_d = db_to_json(row["event_json"])
+ event_d = db_to_json(event_json)
# We don't bother parsing the dicts into full blown event objects,
# as that is needlessly expensive.
diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py
index 654f924019..60621edeef 100644
--- a/synapse/storage/databases/main/experimental_features.py
+++ b/synapse/storage/databases/main/experimental_features.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, FrozenSet
+from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -42,13 +42,16 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
Returns:
the features currently enabled for the user
"""
- enabled = await self.db_pool.simple_select_list(
- "per_user_experimental_features",
- {"user_id": user_id, "enabled": True},
- ["feature"],
+ enabled = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_list(
+ table="per_user_experimental_features",
+ keyvalues={"user_id": user_id, "enabled": True},
+ retcols=("feature",),
+ ),
)
- return frozenset(feature["feature"] for feature in enabled)
+ return frozenset(feature[0] for feature in enabled)
async def set_features_for_user(
self,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ea797864b9..ce88772f9e 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -248,17 +248,20 @@ class KeyStore(CacheInvalidationWorkerStore):
If we have multiple entries for a given key ID, returns the most recent.
"""
- rows = await self.db_pool.simple_select_list(
- table="server_keys_json",
- keyvalues={"server_name": server_name},
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
+ rows = cast(
+ List[Tuple[str, str, int, int, Union[bytes, memoryview]]],
+ await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
),
- desc="get_server_keys_json_for_remote",
)
if not rows:
@@ -266,14 +269,14 @@ class KeyStore(CacheInvalidationWorkerStore):
# We sort the rows by ts_added_ms so that the most recently added entry
# will stomp over older entries in the dictionary.
- rows.sort(key=lambda r: r["ts_added_ms"])
+ rows.sort(key=lambda r: r[2])
return {
- row["key_id"]: FetchKeyResultForRemote(
+ key_id: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
- key_json=bytes(row["key_json"]),
- valid_until_ts=row["ts_valid_until_ms"],
- added_ts=row["ts_added_ms"],
+ key_json=bytes(key_json),
+ valid_until_ts=ts_valid_until_ms,
+ added_ts=ts_added_ms,
)
- for row in rows
+ for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows
}
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 2e6b176bd2..f82140b2e8 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -437,25 +437,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
- rows = await self.db_pool.simple_select_list(
- "local_media_repository_thumbnails",
- {"media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "local_media_repository_thumbnails",
+ {"media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_local_media_thumbnails",
),
- desc="get_local_media_thumbnails",
)
return [
ThumbnailInfo(
- width=row["thumbnail_width"],
- height=row["thumbnail_height"],
- method=row["thumbnail_method"],
- type=row["thumbnail_type"],
- length=row["thumbnail_length"],
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
@@ -568,25 +567,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
) -> List[ThumbnailInfo]:
- rows = await self.db_pool.simple_select_list(
- "remote_media_cache_thumbnails",
- {"media_origin": origin, "media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "remote_media_cache_thumbnails",
+ {"media_origin": origin, "media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_remote_media_thumbnails",
),
- desc="get_remote_media_thumbnails",
)
return [
ThumbnailInfo(
- width=row["thumbnail_width"],
- height=row["thumbnail_height"],
- method=row["thumbnail_method"],
- type=row["thumbnail_type"],
- length=row["thumbnail_length"],
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)
for row in rows
]
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index f5356e7f80..22025eca56 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -179,46 +179,44 @@ class PushRulesWorkerStore(
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
- rows = await self.db_pool.simple_select_list(
- table="push_rules",
- keyvalues={"user_name": user_id},
- retcols=(
- "user_name",
- "rule_id",
- "priority_class",
- "priority",
- "conditions",
- "actions",
+ rows = cast(
+ List[Tuple[str, int, int, str, str]],
+ await self.db_pool.simple_select_list(
+ table="push_rules",
+ keyvalues={"user_name": user_id},
+ retcols=(
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="get_push_rules_for_user",
),
- desc="get_push_rules_for_user",
)
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+ # Sort by highest priority_class, then highest priority.
+ rows.sort(key=lambda row: (-int(row[1]), -int(row[2])))
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
return _load_rules(
- [
- (
- row["rule_id"],
- row["priority_class"],
- row["conditions"],
- row["actions"],
- )
- for row in rows
- ],
+ [(row[0], row[1], row[3], row[4]) for row in rows],
enabled_map,
self.hs.config.experimental,
)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
- results = await self.db_pool.simple_select_list(
- table="push_rules_enable",
- keyvalues={"user_name": user_id},
- retcols=("rule_id", "enabled"),
- desc="get_push_rules_enabled_for_user",
+ results = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ await self.db_pool.simple_select_list(
+ table="push_rules_enable",
+ keyvalues={"user_name": user_id},
+ retcols=("rule_id", "enabled"),
+ desc="get_push_rules_enabled_for_user",
+ ),
)
- return {r["rule_id"]: bool(r["enabled"]) for r in results}
+ return {r[0]: bool(r[1]) for r in results}
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c7eb7fc478..a6a1671bd6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -371,18 +371,20 @@ class PusherWorkerStore(SQLBaseStore):
async def get_throttle_params_by_room(
self, pusher_id: int
) -> Dict[str, ThrottleParams]:
- res = await self.db_pool.simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
+ res = cast(
+ List[Tuple[str, Optional[int], Optional[int]]],
+ await self.db_pool.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ ),
)
params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = ThrottleParams(
- row["last_sent_ts"],
- row["throttle_ms"],
+ for room_id, last_sent_ts, throttle_ms in res:
+ params_by_room[room_id] = ThrottleParams(
+ last_sent_ts or 0, throttle_ms or 0
)
return params_by_room
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9e8643ae4d..b0ef7be155 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Returns:
Tuples of (auth_provider, external_id)
"""
- res = await self.db_pool.simple_select_list(
- table="user_external_ids",
- keyvalues={"user_id": mxid},
- retcols=("auth_provider", "external_id"),
- desc="get_external_ids_by_user",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ ),
)
- return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
@@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
- results = await self.db_pool.simple_select_list(
- "user_threepids",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address", "validated_at", "added_at"],
- desc="user_get_threepids",
+ results = cast(
+ List[Tuple[str, str, int, int]],
+ await self.db_pool.simple_select_list(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address", "validated_at", "added_at"],
+ desc="user_get_threepids",
+ ),
)
- return [ThreepidResult(**r) for r in results]
+ return [
+ ThreepidResult(
+ medium=r[0],
+ address=r[1],
+ validated_at=r[2],
+ added_at=r[3],
+ )
+ for r in results
+ ]
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
@@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="add_user_bound_threepid",
)
- async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
+ async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
@@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id: The ID of the user to retrieve threepids for
Returns:
- List of dictionaries containing the following keys:
- medium (str): The medium of the threepid (e.g "email")
- address (str): The address of the threepid (e.g "bob@example.com")
- """
- return await self.db_pool.simple_select_list(
- table="user_threepid_id_server",
- keyvalues={"user_id": user_id},
- retcols=["medium", "address"],
- desc="user_get_bound_threepids",
+ List of tuples of two strings:
+ medium: The medium of the threepid (e.g "email")
+ address: The address of the threepid (e.g "bob@example.com")
+ """
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="user_threepid_id_server",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address"],
+ desc="user_get_bound_threepids",
+ ),
)
async def remove_user_bound_threepid(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index ce7bfd5146..419b2c7a22 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -384,14 +384,17 @@ class RelationsWorkerStore(SQLBaseStore):
def get_all_relation_ids_for_event_txn(
txn: LoggingTransaction,
) -> List[str]:
- rows = self.db_pool.simple_select_list_txn(
- txn=txn,
- table="event_relations",
- keyvalues={"relates_to_id": event_id},
- retcols=["event_id"],
+ rows = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_list_txn(
+ txn=txn,
+ table="event_relations",
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ ),
)
- return [row["event_id"] for row in rows]
+ return [row[0] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9d24d2c347..3e8fcf1975 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1232,28 +1232,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
room_servers: Dict[str, PartialStateResyncInfo] = {}
- rows = await self.db_pool.simple_select_list(
- table="partial_state_rooms",
- keyvalues={},
- retcols=("room_id", "joined_via"),
- desc="get_server_which_served_partial_join",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="partial_state_rooms",
+ keyvalues={},
+ retcols=("room_id", "joined_via"),
+ desc="get_server_which_served_partial_join",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- joined_via = row["joined_via"]
+ for room_id, joined_via in rows:
room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
- rows = await self.db_pool.simple_select_list(
- "partial_state_rooms_servers",
- keyvalues=None,
- retcols=("room_id", "server_name"),
- desc="get_partial_state_rooms",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ "partial_state_rooms_servers",
+ keyvalues=None,
+ retcols=("room_id", "server_name"),
+ desc="get_partial_state_rooms",
+ ),
)
- for row in rows:
- room_id = row["room_id"]
- server_name = row["server_name"]
+ for room_id, server_name in rows:
entry = room_servers.get(room_id)
if entry is None:
# There is a foreign key constraint which enforces that every room_id in
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3a87eba430..a1627dffb7 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1070,13 +1070,16 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
for fully-joined rooms.
"""
- rows = await self.db_pool.simple_select_list(
- "current_state_events",
- keyvalues={"room_id": room_id},
- retcols=("event_id", "membership"),
- desc="has_completed_background_updates",
+ rows = cast(
+ List[Tuple[str, Optional[str]]],
+ await self.db_pool.simple_select_list(
+ "current_state_events",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "membership"),
+ desc="has_completed_background_updates",
+ ),
)
- return {row["event_id"]: row["membership"] for row in rows}
+ return dict(rows)
# TODO This returns a mutable object, which is generally confusing when using a cache.
@cached(max_entries=10000) # type: ignore[synapse-@cached-mutable]
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 61403a98cf..7deda7790e 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -45,14 +45,17 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag content.
"""
- rows = await self.db_pool.simple_select_list(
- "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.simple_select_list(
+ "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
+ ),
)
tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
- for row in rows:
- room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = db_to_json(row["content"])
+ for room_id, tag, content in rows:
+ room_tags = tags_by_room.setdefault(room_id, {})
+ room_tags[tag] = db_to_json(content)
return tags_by_room
async def get_all_updated_tags(
@@ -161,13 +164,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A mapping of tags to tag content.
"""
- rows = await self.db_pool.simple_select_list(
- table="room_tags",
- keyvalues={"user_id": user_id, "room_id": room_id},
- retcols=("tag", "content"),
- desc="get_tags_for_room",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="room_tags",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ retcols=("tag", "content"),
+ desc="get_tags_for_room",
+ ),
)
- return {row["tag"]: db_to_json(row["content"]) for row in rows}
+ return {tag: db_to_json(content) for tag, content in rows}
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 919c66f553..8ab7c42c4a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -169,13 +169,17 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
- for row in await self.db_pool.simple_select_list(
- table="ui_auth_sessions_credentials",
- keyvalues={"session_id": session_id},
- retcols=("stage_type", "result"),
- desc="get_completed_ui_auth_stages",
- ):
- results[row["stage_type"]] = db_to_json(row["result"])
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ),
+ )
+ for stage_type, result in rows:
+ results[stage_type] = db_to_json(result)
return results
@@ -295,13 +299,15 @@ class UIAuthWorkerStore(SQLBaseStore):
Returns:
List of user_agent/ip pairs
"""
- rows = await self.db_pool.simple_select_list(
- table="ui_auth_sessions_ips",
- keyvalues={"session_id": session_id},
- retcols=("user_agent", "ip"),
- desc="get_user_agents_ips_to_ui_auth_session",
+ return cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ ),
)
- return [(row["user_agent"], row["ip"]) for row in rows]
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 09d2a8c5b3..182e429174 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -154,16 +154,22 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self.db_pool.simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
+ delta_ids = cast(
+ List[Tuple[str, str, str]],
+ self.db_pool.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ ),
)
return _GetStateGroupDelta(
prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ {
+ (event_type, state_key): event_id
+ for event_type, state_key, event_id in delta_ids
+ },
)
return await self.db_pool.runInteraction(
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d11ded6c5b..76c56d5434 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
@@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- async def get_all_room_state(self) -> List[Dict[str, Any]]:
- return await self.store.db_pool.simple_select_list(
- "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
+ async def get_all_room_state(self) -> List[Optional[str]]:
+ rows = cast(
+ List[Tuple[Optional[str]]],
+ await self.store.db_pool.simple_select_list(
+ "room_stats_state", None, retcols=("topic",)
+ ),
)
+ return [r[0] for r in rows]
def _get_current_stats(
self, stats_type: str, stat_id: str
@@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 1)
- self.assertEqual(r[0]["topic"], "foo")
+ self.assertEqual(r[0], "foo")
def test_create_user(self) -> None:
"""
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index 71db47405e..98b01086bc 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
if expected_row is not None:
columns += expected_row.keys()
- rows = self.get_success(
+ row_tuples = self.get_success(
self.store.db_pool.simple_select_list(
table=table,
keyvalues={
@@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
if expected_row is not None:
self.assertEqual(
- len(rows),
+ len(row_tuples),
1,
f"Background update did not leave behind latest receipt in {table}",
)
self.assertEqual(
- rows[0],
- {
- "room_id": room_id,
- "receipt_type": receipt_type,
- "user_id": user_id,
- **expected_row,
- },
+ row_tuples[0],
+ (
+ room_id,
+ receipt_type,
+ user_id,
+ *expected_row.values(),
+ ),
)
else:
self.assertEqual(
- len(rows),
+ len(row_tuples),
0,
f"Background update did not remove all duplicate receipts from {table}",
)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8bbf936ae9..8cbc974ac4 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import secrets
-from typing import Generator, Tuple
+from typing import Generator, List, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
@@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
)
def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
- res = self.get_success(
- self.storage.db_pool.simple_select_list(
- self.table_name, None, ["id, username, value"]
- )
+ yield from cast(
+ List[Tuple[int, str, str]],
+ self.get_success(
+ self.storage.db_pool.simple_select_list(
+ self.table_name, None, ["id, username, value"]
+ )
+ ),
)
- for i in res:
- yield (i["id"], i["username"], i["value"])
-
def test_upsert_many(self) -> None:
"""
Upsert_many will perform the upsert operation across a batch of data.
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index abf7d0564d..3f5bfa09d4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple, cast
from unittest.mock import AsyncMock, Mock
import yaml
@@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
self.wait_for_background_updates()
# Check the correct values are in the new table.
- rows = self.get_success(
- self.store.db_pool.simple_select_list(
- table="test_constraint",
- keyvalues={},
- retcols=("a", "b"),
- )
+ rows = cast(
+ List[Tuple[int, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="test_constraint",
+ keyvalues={},
+ retcols=("a", "b"),
+ )
+ ),
)
- self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+ self.assertCountEqual(rows, [(1, 1), (3, 3)])
# And check that invalid rows get correctly rejected.
self.get_failure(
@@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
self.wait_for_background_updates()
# Check the correct values are in the new table.
- rows = self.get_success(
- self.store.db_pool.simple_select_list(
- table="test_constraint",
- keyvalues={},
- retcols=("a", "b"),
- )
+ rows = cast(
+ List[Tuple[int, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="test_constraint",
+ keyvalues={},
+ retcols=("a", "b"),
+ )
+ ),
)
- self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+ self.assertCountEqual(rows, [(1, 1), (3, 3)])
# And check that invalid rows get correctly rejected.
self.get_failure(
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 256d28e4c9..e4a52c301e 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3
- self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
+ self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)]
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
ret = yield defer.ensureDeferred(
@@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
)
- self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
+ self.assertEqual([(1,), (2,), (3,)], ret)
self.mock_txn.execute.assert_called_with(
"SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 0c054a598f..8e4393d843 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional, Tuple, cast
from unittest.mock import AsyncMock
from parameterized import parameterized
@@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
self.pump(0)
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
self.assertEqual(
- result,
- [
- {
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": None,
- "last_seen": 12345678000,
- }
- ],
+ result, [("access_token", "ip", "user_agent", None, 12345678000)]
)
# Add another & trigger the storage loop
@@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
self.pump(0)
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
# Only one result, has been upserted.
self.assertEqual(
- result,
- [
- {
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": None,
- "last_seen": 12345878000,
- }
- ],
+ result, [("access_token", "ip", "user_agent", None, 12345878000)]
)
@parameterized.expand([(False,), (True,)])
@@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
else:
# Check that the new IP and user agent has not been stored yet
- db_result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="devices",
- keyvalues={},
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ db_result = cast(
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ),
+ ),
),
)
- self.assertEqual(
- db_result,
- [
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": None,
- "user_agent": None,
- "last_seen": None,
- },
- ],
- )
+ self.assertEqual(db_result, [(user_id, None, None, device_id, None)])
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
@@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# Check that the new IP and user agent has not been stored yet
- db_result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="devices",
- keyvalues={},
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ db_result = cast(
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={},
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
),
)
self.assertCountEqual(
db_result,
[
- {
- "user_id": user_id,
- "device_id": device_id_1,
- "ip": "ip_1",
- "user_agent": "user_agent_1",
- "last_seen": 12345678000,
- },
- {
- "user_id": user_id,
- "device_id": device_id_2,
- "ip": "ip_2",
- "user_agent": "user_agent_2",
- "last_seen": 12345678000,
- },
+ (user_id, "ip_1", "user_agent_1", device_id_1, 12345678000),
+ (user_id, "ip_2", "user_agent_2", device_id_2, 12345678000),
],
)
@@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# Check that the new IP and user agent has not been stored yet
- db_result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={},
- retcols=("access_token", "ip", "user_agent", "last_seen"),
+ db_result = cast(
+ List[Tuple[str, str, str, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={},
+ retcols=("access_token", "ip", "user_agent", "last_seen"),
+ ),
),
)
self.assertEqual(
db_result,
[
- {
- "access_token": "access_token",
- "ip": "ip_1",
- "user_agent": "user_agent_1",
- "last_seen": 12345678000,
- },
- {
- "access_token": "access_token",
- "ip": "ip_2",
- "user_agent": "user_agent_2",
- "last_seen": 12345678000,
- },
+ ("access_token", "ip_1", "user_agent_1", 12345678000),
+ ("access_token", "ip_2", "user_agent_2", 12345678000),
],
)
@@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
# We should see that in the DB
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
self.assertEqual(
result,
- [
- {
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": device_id,
- "last_seen": 0,
- }
- ],
+ [("access_token", "ip", "user_agent", device_id, 0)],
)
# Now advance by a couple of months
self.reactor.advance(60 * 24 * 60 * 60)
# We should get no results.
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
self.assertEqual(result, [])
@@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
# We should see that in the DB
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
# ensure user1 is filtered out
- self.assertEqual(
- result,
- [
- {
- "access_token": access_token2,
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": device_id2,
- "last_seen": 0,
- }
- ],
- )
+ self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)])
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f4c4661aaf..36fcab06b5 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional, Tuple, cast
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import Membership
@@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- res = self.get_success(
- self.store.db_pool.simple_select_list(
- "room_memberships",
- {"user_id": "@alice:test"},
- ["display_name", "event_id"],
- )
+ res = cast(
+ List[Tuple[Optional[str], str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ ),
)
# Check that we only got one result back
self.assertEqual(len(res), 1)
# Check that alice's display name is "alice"
- self.assertEqual(res[0]["display_name"], "alice")
+ self.assertEqual(res[0][0], "alice")
# Grab the event_id to use later
- event_id = res[0]["event_id"]
+ event_id = res[0][1]
# Create a profile with the offending null byte in the display name
new_profile = {"displayname": "ali\u0000ce"}
@@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
tok=self.t_alice,
)
- res2 = self.get_success(
- self.store.db_pool.simple_select_list(
- "room_memberships",
- {"user_id": "@alice:test"},
- ["display_name", "event_id"],
- )
+ res2 = cast(
+ List[Tuple[Optional[str], str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ ),
)
# Check that we only have two results
self.assertEqual(len(res2), 2)
# Filter out the previous event using the event_id we grabbed above
- row = [row for row in res2 if row["event_id"] != event_id]
+ row = [row for row in res2 if row[1] != event_id]
# Check that alice's display name is now None
- self.assertEqual(row[0]["display_name"], None)
+ self.assertIsNone(row[0][0])
def test_room_is_locally_forgotten(self) -> None:
"""Test that when the last local user has forgotten a room it is known as forgotten."""
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b9446c36c..2715c73f16 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from typing import List, Tuple, cast
from immutabledict import immutabledict
@@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
)
# check that only state events are in state_groups, and all state events are in state_groups
- res = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_groups",
- keyvalues=None,
- retcols=("event_id",),
- )
+ res = cast(
+ List[Tuple[str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups",
+ keyvalues=None,
+ retcols=("event_id",),
+ )
+ ),
)
events = []
for result in res:
- self.assertNotIn(event3.event_id, result)
- events.append(result.get("event_id"))
+ self.assertNotIn(event3.event_id, result) # XXX
+ events.append(result[0])
for event, _ in processed_events_and_context:
if event.is_state():
@@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
# has an entry and prev event in state_group_edges
for event, context in processed_events_and_context:
if event.is_state():
- state = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_groups_state",
- keyvalues={"state_group": context.state_group_after_event},
- retcols=("type", "state_key"),
- )
- )
- self.assertEqual(event.type, state[0].get("type"))
- self.assertEqual(event.state_key, state[0].get("state_key"))
-
- groups = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_group_edges",
- keyvalues={"state_group": str(context.state_group_after_event)},
- retcols=("*",),
- )
+ state = cast(
+ List[Tuple[str, str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups_state",
+ keyvalues={"state_group": context.state_group_after_event},
+ retcols=("type", "state_key"),
+ )
+ ),
)
- self.assertEqual(
- context.state_group_before_event, groups[0].get("prev_state_group")
+ self.assertEqual(event.type, state[0][0])
+ self.assertEqual(event.state_key, state[0][1])
+
+ groups = cast(
+ List[Tuple[str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_group_edges",
+ keyvalues={
+ "state_group": str(context.state_group_after_event)
+ },
+ retcols=("prev_state_group",),
+ )
+ ),
)
+ self.assertEqual(context.state_group_before_event, groups[0][0])
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 8c72aa1722..822c41dd9f 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
-from typing import Any, Dict, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, cast
from unittest import mock
from unittest.mock import Mock, patch
@@ -62,14 +62,13 @@ class GetUserDirectoryTables:
Returns a list of tuples (user_id, room_id) where room_id is public and
contains the user with the given id.
"""
- r = await self.store.db_pool.simple_select_list(
- "users_in_public_rooms", None, ("user_id", "room_id")
+ r = cast(
+ List[Tuple[str, str]],
+ await self.store.db_pool.simple_select_list(
+ "users_in_public_rooms", None, ("user_id", "room_id")
+ ),
)
-
- retval = set()
- for i in r:
- retval.add((i["user_id"], i["room_id"]))
- return retval
+ return set(r)
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
"""Fetch the entire `users_who_share_private_rooms` table.
@@ -78,27 +77,30 @@ class GetUserDirectoryTables:
to the rows of `users_who_share_private_rooms`.
"""
- rows = await self.store.db_pool.simple_select_list(
- "users_who_share_private_rooms",
- None,
- ["user_id", "other_user_id", "room_id"],
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.store.db_pool.simple_select_list(
+ "users_who_share_private_rooms",
+ None,
+ ["user_id", "other_user_id", "room_id"],
+ ),
)
- rv = set()
- for row in rows:
- rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
- return rv
+ return set(rows)
async def get_users_in_user_directory(self) -> Set[str]:
"""Fetch the set of users in the `user_directory` table.
This is useful when checking we've correctly excluded users from the directory.
"""
- result = await self.store.db_pool.simple_select_list(
- "user_directory",
- None,
- ["user_id"],
+ result = cast(
+ List[Tuple[str]],
+ await self.store.db_pool.simple_select_list(
+ "user_directory",
+ None,
+ ["user_id"],
+ ),
)
- return {row["user_id"] for row in result}
+ return {row[0] for row in result}
async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
"""Fetch users and their profiles from the `user_directory` table.
@@ -107,16 +109,17 @@ class GetUserDirectoryTables:
It's almost the entire contents of the `user_directory` table: the only
thing missing is an unused room_id column.
"""
- rows = await self.store.db_pool.simple_select_list(
- "user_directory",
- None,
- ("user_id", "display_name", "avatar_url"),
+ rows = cast(
+ List[Tuple[str, Optional[str], Optional[str]]],
+ await self.store.db_pool.simple_select_list(
+ "user_directory",
+ None,
+ ("user_id", "display_name", "avatar_url"),
+ ),
)
return {
- row["user_id"]: ProfileInfo(
- display_name=row["display_name"], avatar_url=row["avatar_url"]
- )
- for row in rows
+ user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url)
+ for user_id, display_name, avatar_url in rows
}
async def get_tables(
|