summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-08-18 11:05:01 +0100
committerGitHub <noreply@github.com>2023-08-18 11:05:01 +0100
commit0aba4a4eaac778ad75509fe20733b27bfc86fd9d (patch)
tree209aec9751ab79b221d75ef7e9cb5c50391b3544 /synapse/rest
parentCache token introspection response from OIDC provider (#16117) (diff)
downloadsynapse-0aba4a4eaac778ad75509fe20733b27bfc86fd9d.tar.xz
Add cache to `get_server_keys_json_for_remote` (#16123)
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py44
1 files changed, 25 insertions, 19 deletions
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 8f3865d412..981fd1f58a 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
 
 from signedjson.sign import sign_json
 
@@ -27,6 +27,7 @@ from synapse.http.servlet import (
     parse_integer,
     parse_json_object_from_request,
 )
+from synapse.storage.keys import FetchKeyResultForRemote
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
     ) -> JsonDict:
         logger.info("Handling query for keys %r", query)
 
-        store_queries = []
+        server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
         for server_name, key_ids in query.items():
-            if not key_ids:
-                key_ids = (None,)
-            for key_id in key_ids:
-                store_queries.append((server_name, key_id, None))
+            if key_ids:
+                results: Mapping[
+                    str, Optional[FetchKeyResultForRemote]
+                ] = await self.store.get_server_keys_json_for_remote(
+                    server_name, key_ids
+                )
+            else:
+                results = await self.store.get_all_server_keys_json_for_remote(
+                    server_name
+                )
 
-        cached = await self.store.get_server_keys_json_for_remote(store_queries)
+            server_keys.update(
+                ((server_name, key_id), res) for key_id, res in results.items()
+            )
 
         json_results: Set[bytes] = set()
 
@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
         # Map server_name->key_id->int. Note that the value of the int is unused.
         # XXX: why don't we just use a set?
         cache_misses: Dict[str, Dict[str, int]] = {}
-        for (server_name, key_id, _), key_results in cached.items():
-            results = [(result["ts_added_ms"], result) for result in key_results]
-
-            if key_id is None:
+        for (server_name, key_id), key_result in server_keys.items():
+            if not query[server_name]:
                 # all keys were requested. Just return what we have without worrying
                 # about validity
-                for _, result in results:
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(result["key_json"]))
+                if key_result:
+                    json_results.add(key_result.key_json)
                 continue
 
             miss = False
-            if not results:
+            if key_result is None:
                 miss = True
             else:
-                ts_added_ms, most_recent_result = max(results)
-                ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+                ts_added_ms = key_result.added_ts
+                ts_valid_until_ms = key_result.valid_until_ts
                 req_key = query.get(server_name, {}).get(key_id, {})
                 req_valid_until = req_key.get("minimum_valid_until_ts")
                 if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
                         ts_valid_until_ms,
                         time_now_ms,
                     )
-                # Cast to bytes since postgresql returns a memoryview.
-                json_results.add(bytes(most_recent_result["key_json"]))
+
+                json_results.add(key_result.key_json)
 
             if miss and query_remote_on_cache_miss:
                 # only bother attempting to fetch keys from servers on our whitelist