summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 85c1778a81..1ca66d57d4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -202,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+    async def count_devices_by_users(
+        self, user_ids: Optional[Collection[str]] = None
+    ) -> int:
         """Retrieve number of all devices of given users.
         Only returns number of devices that are not marked as hidden.
 
@@ -213,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         """
 
         def count_devices_by_users_txn(
-            txn: LoggingTransaction, user_ids: List[str]
+            txn: LoggingTransaction, user_ids: Collection[str]
         ) -> int:
             sql = """
                 SELECT count(*)
@@ -747,7 +750,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     @cancellable
     async def get_user_devices_from_cache(
         self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
-    ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+    ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
         Args:
@@ -775,16 +778,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 
         # First fetch all the users which all devices are to be returned.
-        results: Dict[str, Dict[str, JsonDict]] = {}
+        results: Dict[str, Mapping[str, JsonDict]] = {}
         for user_id in user_ids:
             if user_id in user_ids_in_cache:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
         # Then fetch all device-specific requests, but skip users we've already
         # fetched all devices for.
+        device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
         for user_id, device_id in user_and_device_ids:
             if user_id in user_ids_in_cache and user_id not in user_ids:
                 device = await self._get_cached_user_device(user_id, device_id)
-                results.setdefault(user_id, {})[device_id] = device
+                device_specific_results.setdefault(user_id, {})[device_id] = device
+        results.update(device_specific_results)
 
         set_tag("in_cache", str(results))
         set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -802,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return db_to_json(content)
 
     @cached()
-    async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+    async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
         devices = await self.db_pool.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},