diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index cea32a034a..a3b4744855 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
import itertools
import json
import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""
@cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- self._get_server_keys_json.invalidate(((server_name, key_id),))
+ await self.invalidate_cache_and_stream(
+ "_get_server_keys_json", ((server_name, key_id),)
+ )
+ await self.invalidate_cache_and_stream(
+ "get_server_key_json_for_remote", (server_name, key_id)
+ )
@cached()
def _get_server_keys_json(
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
- async def get_server_keys_json_for_remote(
- self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- """Retrieve the key json for a list of server_keys and key ids.
- If no keys are found for a given server, key_id and source then
- that server, key_id, and source triplet entry will be an empty list.
- The JSON is returned as a byte array so that it can be efficiently
- used in an HTTP response.
+ @cached()
+ def get_server_key_json_for_remote(
+ self,
+ server_name: str,
+ key_id: str,
+ ) -> Optional[FetchKeyResultForRemote]:
+ raise NotImplementedError()
- Args:
- server_keys: List of (server_name, key_id, source) triplets.
+ @cachedList(
+ cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+ )
+ async def get_server_keys_json_for_remote(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+ """Fetch the cached keys for the given server/key IDs.
- Returns:
- A mapping from (server_name, key_id, source) triplets to a list of dicts
+ If we have multiple entries for a given key ID, returns the most recent.
"""
+ rows = await self.db_pool.simple_select_many_batch(
+ table="server_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
+ )
- def _get_server_keys_json_txn(
- txn: LoggingTransaction,
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- results = {}
- for server_name, key_id, from_server in server_keys:
- keyvalues = {"server_name": server_name}
- if key_id is not None:
- keyvalues["key_id"] = key_id
- if from_server is not None:
- keyvalues["from_server"] = from_server
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "server_keys_json",
- keyvalues=keyvalues,
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
- ),
- )
- results[(server_name, key_id, from_server)] = rows
- return results
+ if not rows:
+ return {}
+
+ # We sort the rows so that the most recently added entry is picked up.
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
- return await self.db_pool.runInteraction(
- "get_server_keys_json", _get_server_keys_json_txn
+ async def get_all_server_keys_json_for_remote(
+ self,
+ server_name: str,
+ ) -> Dict[str, FetchKeyResultForRemote]:
+ """Fetch the cached keys for the given server.
+
+ If we have multiple entries for a given key ID, returns the most recent.
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
)
+
+ if not rows:
+ return {}
+
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
|