diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a3b6c8ae8e..dc7768c50c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -51,7 +51,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
@@ -1028,14 +1028,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Dict[str, Dict[str, Dict[str, str]]]:
+ ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Take a list of one time keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
- A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+ A tuple pf:
+ A map of user ID -> a map device ID -> a map of key ID -> JSON.
+
+ A copy of the input which has not been fulfilled.
"""
@trace
@@ -1115,7 +1118,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
- results: Dict[str, Dict[str, Dict[str, str]]] = {}
+ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ missing: List[Tuple[str, str, str]] = []
for user_id, device_id, algorithm in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
@@ -1138,11 +1142,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- device_results[claim_row[0]] = claim_row[1]
- continue
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ else:
+ missing.append((user_id, device_id, algorithm))
+
+ return results, missing
+
+ async def claim_e2e_fallback_keys(
+ self, query_list: Iterable[Tuple[str, str, str]]
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Take a list of fallback keys out of the database.
- # No one-time key available, so see if there's a fallback
- # key
+ Args:
+ query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A map of user ID -> a map device ID -> a map of key ID -> JSON.
+ """
+ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ for user_id, device_id, algorithm in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
keyvalues={
@@ -1179,7 +1197,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
- device_results[f"{algorithm}:{key_id}"] = key_json
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
return results
|