summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py47
2 files changed, 42 insertions, 7 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index adde5d0978..7a6ed332aa 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
     @trace
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, str]]
+        self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9b293475c8..60f622ad71 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -22,11 +22,14 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Union,
     cast,
+    overload,
 )
 
 import attr
 from canonicaljson import encode_canonical_json
+from typing_extensions import Literal
 
 from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.appservice import (
@@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             user_devices = devices[user_id]
             results = []
             for device_id, device in user_devices.items():
-                result = {"device_id": device_id}
+                result: JsonDict = {"device_id": device_id}
 
                 keys = device.keys
                 if keys:
@@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
                 r = device_info.keys
+                if r is None:
+                    continue
+
                 r["unsigned"] = {}
                 display_name = device_info.display_name
                 if display_name is not None:
@@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return rv
 
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: bool = False,
+        include_deleted_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[True],
+        include_deleted_devices: Literal[True],
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        ...
+
     @trace
     async def get_e2e_device_keys_and_signatures(
         self,
-        query_list: List[Tuple[str, Optional[str]]],
+        query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
-    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+    ) -> Union[
+        Dict[str, Dict[str, DeviceKeyLookupResult]],
+        Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
+    ]:
         """Fetch a list of device keys
 
         Any cross-signatures made on the keys by the owner of the device are also
@@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
                 db_autocommit = False
 
-            row = await self.db_pool.runInteraction(
+            claim_row = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
                 _claim_e2e_one_time_key,
                 user_id,
@@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 algorithm,
                 db_autocommit=db_autocommit,
             )
-            if row:
+            if claim_row:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[row[0]] = row[1]
+                device_results[claim_row[0]] = claim_row[1]
                 continue
 
             # No one-time key available, so see if there's a fallback