summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py70
1 files changed, 40 insertions, 30 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fc23d18eba..0b75f6763a 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -283,7 +283,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             allow_none=True,
         )
 
-    async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
+    async def get_devices_by_user(
+        self, user_id: str
+    ) -> Dict[str, Dict[str, Optional[str]]]:
         """Retrieve all of a user's registered devices. Only returns devices
         that are not marked as hidden.
 
@@ -291,20 +293,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             user_id:
         Returns:
             A mapping from device_id to a dict containing "device_id", "user_id"
-            and "display_name" for each device.
+            and "display_name" for each device. Display name may be null.
         """
-        devices = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues={"user_id": user_id, "hidden": False},
-            retcols=("user_id", "device_id", "display_name"),
-            desc="get_devices_by_user",
+        devices = cast(
+            List[Tuple[str, str, Optional[str]]],
+            await self.db_pool.simple_select_list(
+                table="devices",
+                keyvalues={"user_id": user_id, "hidden": False},
+                retcols=("user_id", "device_id", "display_name"),
+                desc="get_devices_by_user",
+            ),
         )
 
-        return {d["device_id"]: d for d in devices}
+        return {
+            d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
+            for d in devices
+        }
 
     async def get_devices_by_auth_provider_session_id(
         self, auth_provider_id: str, auth_provider_session_id: str
-    ) -> List[Dict[str, Any]]:
+    ) -> List[Tuple[str, str]]:
         """Retrieve the list of devices associated with a SSO IdP session ID.
 
         Args:
@@ -313,14 +321,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         Returns:
             A list of dicts containing the device_id and the user_id of each device
         """
-        return await self.db_pool.simple_select_list(
-            table="device_auth_providers",
-            keyvalues={
-                "auth_provider_id": auth_provider_id,
-                "auth_provider_session_id": auth_provider_session_id,
-            },
-            retcols=("user_id", "device_id"),
-            desc="get_devices_by_auth_provider_session_id",
+        return cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_list(
+                table="device_auth_providers",
+                keyvalues={
+                    "auth_provider_id": auth_provider_id,
+                    "auth_provider_session_id": auth_provider_session_id,
+                },
+                retcols=("user_id", "device_id"),
+                desc="get_devices_by_auth_provider_session_id",
+            ),
         )
 
     @trace
@@ -821,15 +832,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     async def get_cached_devices_for_user(
         self, user_id: str
     ) -> Mapping[str, JsonMapping]:
-        devices = await self.db_pool.simple_select_list(
-            table="device_lists_remote_cache",
-            keyvalues={"user_id": user_id},
-            retcols=("device_id", "content"),
-            desc="get_cached_devices_for_user",
+        devices = cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_list(
+                table="device_lists_remote_cache",
+                keyvalues={"user_id": user_id},
+                retcols=("device_id", "content"),
+                desc="get_cached_devices_for_user",
+            ),
         )
-        return {
-            device["device_id"]: db_to_json(device["content"]) for device in devices
-        }
+        return {device[0]: db_to_json(device[1]) for device in devices}
 
     def get_cached_device_list_changes(
         self,
@@ -1080,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             The IDs of users whose device lists need resync.
         """
         if user_ids:
-            row_tuples = cast(
+            rows = cast(
                 List[Tuple[str]],
                 await self.db_pool.simple_select_many_batch(
                     table="device_lists_remote_resync",
@@ -1090,11 +1102,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                     desc="get_user_ids_requiring_device_list_resync_with_iterable",
                 ),
             )
-
-            return {row[0] for row in row_tuples}
         else:
             rows = cast(
-                List[Dict[str, str]],
+                List[Tuple[str]],
                 await self.db_pool.simple_select_list(
                     table="device_lists_remote_resync",
                     keyvalues=None,
@@ -1103,7 +1113,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 ),
             )
 
-            return {row["user_id"] for row in rows}
+        return {row[0] for row in rows}
 
     async def mark_remote_users_device_caches_as_stale(
         self, user_ids: StrCollection