diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9f3804a504..fc23d18eba 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Mapping[str, Optional[str]]:
- rows = await self.db_pool.simple_select_many_batch(
- table="device_lists_remote_extremeties",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id", "stream_id"),
- desc="get_device_list_last_stream_id_for_remotes",
+ rows = cast(
+ List[Tuple[str, str]],
+ await self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id"),
+ desc="get_device_list_last_stream_id_for_remotes",
+ ),
)
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
- results.update({row["user_id"]: row["stream_id"] for row in rows})
+ results.update(rows)
return results
@@ -1077,22 +1080,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- rows = await self.db_pool.simple_select_many_batch(
- table="device_lists_remote_resync",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ row_tuples = cast(
+ List[Tuple[str]],
+ await self.db_pool.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ ),
)
+
+ return {row[0] for row in row_tuples}
else:
- rows = await self.db_pool.simple_select_list(
- table="device_lists_remote_resync",
- keyvalues=None,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync",
+ rows = cast(
+ List[Dict[str, str]],
+ await self.db_pool.simple_select_list(
+ table="device_lists_remote_resync",
+ keyvalues=None,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ ),
)
- return {row["user_id"] for row in rows}
+ return {row["user_id"] for row in rows}
async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
|