diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 8f3865d412..981fd1f58a 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from signedjson.sign import sign_json
@@ -27,6 +27,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
)
+from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
) -> JsonDict:
logger.info("Handling query for keys %r", query)
- store_queries = []
+ server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items():
- if not key_ids:
- key_ids = (None,)
- for key_id in key_ids:
- store_queries.append((server_name, key_id, None))
+ if key_ids:
+ results: Mapping[
+ str, Optional[FetchKeyResultForRemote]
+ ] = await self.store.get_server_keys_json_for_remote(
+ server_name, key_ids
+ )
+ else:
+ results = await self.store.get_all_server_keys_json_for_remote(
+ server_name
+ )
- cached = await self.store.get_server_keys_json_for_remote(store_queries)
+ server_keys.update(
+ ((server_name, key_id), res) for key_id, res in results.items()
+ )
json_results: Set[bytes] = set()
@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
# Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
- for (server_name, key_id, _), key_results in cached.items():
- results = [(result["ts_added_ms"], result) for result in key_results]
-
- if key_id is None:
+ for (server_name, key_id), key_result in server_keys.items():
+ if not query[server_name]:
# all keys were requested. Just return what we have without worrying
# about validity
- for _, result in results:
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(result["key_json"]))
+ if key_result:
+ json_results.add(key_result.key_json)
continue
miss = False
- if not results:
+ if key_result is None:
miss = True
else:
- ts_added_ms, most_recent_result = max(results)
- ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+ ts_added_ms = key_result.added_ts
+ ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
ts_valid_until_ms,
time_now_ms,
)
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(most_recent_result["key_json"]))
+
+ json_results.add(key_result.key_json)
if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist
|