summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py49
1 files changed, 30 insertions, 19 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9f3804a504..fc23d18eba 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1052,16 +1052,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     async def get_device_list_last_stream_id_for_remotes(
         self, user_ids: Iterable[str]
     ) -> Mapping[str, Optional[str]]:
-        rows = await self.db_pool.simple_select_many_batch(
-            table="device_lists_remote_extremeties",
-            column="user_id",
-            iterable=user_ids,
-            retcols=("user_id", "stream_id"),
-            desc="get_device_list_last_stream_id_for_remotes",
+        rows = cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="device_lists_remote_extremeties",
+                column="user_id",
+                iterable=user_ids,
+                retcols=("user_id", "stream_id"),
+                desc="get_device_list_last_stream_id_for_remotes",
+            ),
         )
 
         results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
-        results.update({row["user_id"]: row["stream_id"] for row in rows})
+        results.update(rows)
 
         return results
 
@@ -1077,22 +1080,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             The IDs of users whose device lists need resync.
         """
         if user_ids:
-            rows = await self.db_pool.simple_select_many_batch(
-                table="device_lists_remote_resync",
-                column="user_id",
-                iterable=user_ids,
-                retcols=("user_id",),
-                desc="get_user_ids_requiring_device_list_resync_with_iterable",
+            row_tuples = cast(
+                List[Tuple[str]],
+                await self.db_pool.simple_select_many_batch(
+                    table="device_lists_remote_resync",
+                    column="user_id",
+                    iterable=user_ids,
+                    retcols=("user_id",),
+                    desc="get_user_ids_requiring_device_list_resync_with_iterable",
+                ),
             )
+
+            return {row[0] for row in row_tuples}
         else:
-            rows = await self.db_pool.simple_select_list(
-                table="device_lists_remote_resync",
-                keyvalues=None,
-                retcols=("user_id",),
-                desc="get_user_ids_requiring_device_list_resync",
+            rows = cast(
+                List[Dict[str, str]],
+                await self.db_pool.simple_select_list(
+                    table="device_lists_remote_resync",
+                    keyvalues=None,
+                    retcols=("user_id",),
+                    desc="get_user_ids_requiring_device_list_resync",
+                ),
             )
 
-        return {row["user_id"] for row in rows}
+            return {row["user_id"] for row in rows}
 
     async def mark_remote_users_device_caches_as_stale(
         self, user_ids: StrCollection