summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py36
1 files changed, 27 insertions, 9 deletions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a3b6c8ae8e..dc7768c50c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -51,7 +51,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
@@ -1028,14 +1028,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Dict[str, Dict[str, Dict[str, str]]]:
+    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
         """Take a list of one time keys out of the database.
 
         Args:
             query_list: An iterable of tuples of (user ID, device ID, algorithm).
 
         Returns:
-            A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+            A tuple pf:
+                A map of user ID -> a map device ID -> a map of key ID -> JSON.
+
+                A copy of the input which has not been fulfilled.
         """
 
         @trace
@@ -1115,7 +1118,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             key_id, key_json = otk_row
             return f"{algorithm}:{key_id}", key_json
 
-        results: Dict[str, Dict[str, Dict[str, str]]] = {}
+        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        missing: List[Tuple[str, str, str]] = []
         for user_id, device_id, algorithm in query_list:
             if self.database_engine.supports_returning:
                 # If we support RETURNING clause we can use a single query that
@@ -1138,11 +1142,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[claim_row[0]] = claim_row[1]
-                continue
+                device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+            else:
+                missing.append((user_id, device_id, algorithm))
+
+        return results, missing
+
+    async def claim_e2e_fallback_keys(
+        self, query_list: Iterable[Tuple[str, str, str]]
+    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+        """Take a list of fallback keys out of the database.
 
-            # No one-time key available, so see if there's a fallback
-            # key
+        Args:
+            query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A map of user ID -> a map device ID -> a map of key ID -> JSON.
+        """
+        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        for user_id, device_id, algorithm in query_list:
             row = await self.db_pool.simple_select_one(
                 table="e2e_fallback_keys_json",
                 keyvalues={
@@ -1179,7 +1197,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 )
 
             device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
-            device_results[f"{algorithm}:{key_id}"] = key_json
+            device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
 
         return results