diff --git a/changelog.d/16123.misc b/changelog.d/16123.misc
new file mode 100644
index 0000000000..b7c6b7c2f2
--- /dev/null
+++ b/changelog.d/16123.misc
@@ -0,0 +1 @@
+Add cache to `get_server_keys_json_for_remote`.
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 8f3865d412..981fd1f58a 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from signedjson.sign import sign_json
@@ -27,6 +27,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
)
+from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
) -> JsonDict:
logger.info("Handling query for keys %r", query)
- store_queries = []
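+        # Map from (server_name, key_id) to the cached key result, or None
+        # where a specific key ID was requested but we have no copy of it.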
+ server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items():
- if not key_ids:
- key_ids = (None,)
- for key_id in key_ids:
- store_queries.append((server_name, key_id, None))
+ if key_ids:
+ results: Mapping[
+ str, Optional[FetchKeyResultForRemote]
+ ] = await self.store.get_server_keys_json_for_remote(
+ server_name, key_ids
+ )
+ else:
+ results = await self.store.get_all_server_keys_json_for_remote(
+ server_name
+ )
- cached = await self.store.get_server_keys_json_for_remote(store_queries)
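+            # Merge this server's results into a single map keyed on
+            # (server_name, key_id) so the validity loop below can treat
+            # per-key and all-key lookups uniformly.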
+ server_keys.update(
+ ((server_name, key_id), res) for key_id, res in results.items()
+ )
json_results: Set[bytes] = set()
@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
# Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
- for (server_name, key_id, _), key_results in cached.items():
- results = [(result["ts_added_ms"], result) for result in key_results]
-
- if key_id is None:
+ for (server_name, key_id), key_result in server_keys.items():
+ if not query[server_name]:
# all keys were requested. Just return what we have without worrying
# about validity
- for _, result in results:
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(result["key_json"]))
+ if key_result:
+ json_results.add(key_result.key_json)
continue
miss = False
- if not results:
+ if key_result is None:
miss = True
else:
- ts_added_ms, most_recent_result = max(results)
- ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+ ts_added_ms = key_result.added_ts
+ ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
ts_valid_until_ms,
time_now_ms,
)
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(most_recent_result["key_json"]))
+
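+                # `key_json` is already `bytes`: the store casts away the
+                # postgres memoryview when building the FetchKeyResultForRemote.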
+ json_results.add(key_result.key_json)
if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index cea32a034a..a3b4744855 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
import itertools
import json
import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""
@cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- self._get_server_keys_json.invalidate(((server_name, key_id),))
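+        # `invalidate_cache_and_stream` invalidates our local copy and also
+        # streams the invalidation out so that workers drop theirs too.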
+ await self.invalidate_cache_and_stream(
+ "_get_server_keys_json", ((server_name, key_id),)
+ )
+ await self.invalidate_cache_and_stream(
+ "get_server_key_json_for_remote", (server_name, key_id)
+ )
@cached()
def _get_server_keys_json(
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
- async def get_server_keys_json_for_remote(
- self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- """Retrieve the key json for a list of server_keys and key ids.
- If no keys are found for a given server, key_id and source then
- that server, key_id, and source triplet entry will be an empty list.
- The JSON is returned as a byte array so that it can be efficiently
- used in an HTTP response.
+ @cached()
+ def get_server_key_json_for_remote(
+ self,
+ server_name: str,
+ key_id: str,
+ ) -> Optional[FetchKeyResultForRemote]:
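+        # The body itself is never run: this method exists only to provide
+        # the per-key cache that the `@cachedList` method below populates.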
+ raise NotImplementedError()
- Args:
- server_keys: List of (server_name, key_id, source) triplets.
+ @cachedList(
+ cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+ )
+ async def get_server_keys_json_for_remote(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+ """Fetch the cached keys for the given server/key IDs.
- Returns:
- A mapping from (server_name, key_id, source) triplets to a list of dicts
+        If we have multiple entries for a given key ID, returns the most recent.
+        Requested key IDs that we have no entry for come back as None (filled
+        in by the `@cachedList` wrapper).
"""
+ rows = await self.db_pool.simple_select_many_batch(
+ table="server_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
+ )
- def _get_server_keys_json_txn(
- txn: LoggingTransaction,
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- results = {}
- for server_name, key_id, from_server in server_keys:
- keyvalues = {"server_name": server_name}
- if key_id is not None:
- keyvalues["key_id"] = key_id
- if from_server is not None:
- keyvalues["from_server"] = from_server
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "server_keys_json",
- keyvalues=keyvalues,
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
- ),
- )
- results[(server_name, key_id, from_server)] = rows
- return results
+ if not rows:
+ return {}
+
+        # Sort ascending by ts_added_ms so that, in the dict comprehension
+        # below, the most recently added entry for each key ID overwrites any
+        # older ones.
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
- return await self.db_pool.runInteraction(
- "get_server_keys_json", _get_server_keys_json_txn
+ async def get_all_server_keys_json_for_remote(
+ self,
+ server_name: str,
+ ) -> Dict[str, FetchKeyResultForRemote]:
+ """Fetch the cached keys for the given server.
+
+ If we have multiple entries for a given key ID, returns the most recent.
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+            desc="get_all_server_keys_json_for_remote",
)
+
+ if not rows:
+ return {}
+
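+        # Sort ascending by ts_added_ms so the dict comprehension below keeps
+        # the most recently added entry for each key ID.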
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 71584f3f74..e74b2269d2 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
class FetchKeyResult:
verify_key: VerifyKey # the key itself
valid_until_ts: int # how long we can use this key for
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FetchKeyResultForRemote:
+ key_json: bytes # the full key JSON
+ valid_until_ts: int # how long we can use this key for, in milliseconds.
+    added_ts: int  # when we added this key, in milliseconds.
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index fdfd4f911d..2be341ac7b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], SERVER_NAME)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
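+        # The bare `assert` (on top of assertIsNotNone) narrows the Optional
+        # for mypy.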
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
# we expect it to be encoded as canonical json *before* it hits the db
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""
|