diff options
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/keys.py | 132 |
1 files changed, 88 insertions, 44 deletions
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index cea32a034a..a3b4744855 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,14 +16,13 @@ import itertools import json import logging -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Dict, Iterable, Mapping, Optional, Tuple from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction -from synapse.storage.keys import FetchKeyResult +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -34,7 +33,7 @@ logger = logging.getLogger(__name__) db_binary_type = memoryview -class KeyStore(SQLBaseStore): +class KeyStore(CacheInvalidationWorkerStore): """Persistence for signature verification keys""" @cached() @@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore): # invalidate takes a tuple corresponding to the params of # _get_server_keys_json. _get_server_keys_json only takes one # param, which is itself the 2-tuple (server_name, key_id). - self._get_server_keys_json.invalidate(((server_name, key_id),)) + await self.invalidate_cache_and_stream( + "_get_server_keys_json", ((server_name, key_id),) + ) + await self.invalidate_cache_and_stream( + "get_server_key_json_for_remote", (server_name, key_id) + ) @cached() def _get_server_keys_json( @@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore): return await self.db_pool.runInteraction("get_server_keys_json", _txn) - async def get_server_keys_json_for_remote( - self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: - """Retrieve the key json for a list of server_keys and key ids. - If no keys are found for a given server, key_id and source then - that server, key_id, and source triplet entry will be an empty list. - The JSON is returned as a byte array so that it can be efficiently - used in an HTTP response. + @cached() + def get_server_key_json_for_remote( + self, + server_name: str, + key_id: str, + ) -> Optional[FetchKeyResultForRemote]: + raise NotImplementedError() - Args: - server_keys: List of (server_name, key_id, source) triplets. + @cachedList( + cached_method_name="get_server_key_json_for_remote", list_name="key_ids" + ) + async def get_server_keys_json_for_remote( + self, server_name: str, key_ids: Iterable[str] + ) -> Dict[str, Optional[FetchKeyResultForRemote]]: + """Fetch the cached keys for the given server/key IDs. - Returns: - A mapping from (server_name, key_id, source) triplets to a list of dicts + If we have multiple entries for a given key ID, returns the most recent. """ + rows = await self.db_pool.simple_select_many_batch( + table="server_keys_json", + column="key_id", + iterable=key_ids, + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", + ) - def _get_server_keys_json_txn( - txn: LoggingTransaction, - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: - results = {} - for server_name, key_id, from_server in server_keys: - keyvalues = {"server_name": server_name} - if key_id is not None: - keyvalues["key_id"] = key_id - if from_server is not None: - keyvalues["from_server"] = from_server - rows = self.db_pool.simple_select_list_txn( - txn, - "server_keys_json", - keyvalues=keyvalues, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", - ), - ) - results[(server_name, key_id, from_server)] = rows - return results + if not rows: + return {} + + # We sort the rows so that the most recently added entry is picked up. + rows.sort(key=lambda r: r["ts_added_ms"]) + + return { + row["key_id"]: FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + } - return await self.db_pool.runInteraction( - "get_server_keys_json", _get_server_keys_json_txn + async def get_all_server_keys_json_for_remote( + self, + server_name: str, + ) -> Dict[str, FetchKeyResultForRemote]: + """Fetch the cached keys for the given server. + + If we have multiple entries for a given key ID, returns the most recent. + """ + rows = await self.db_pool.simple_select_list( + table="server_keys_json", + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", ) + + if not rows: + return {} + + rows.sort(key=lambda r: r["ts_added_ms"]) + + return { + row["key_id"]: FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + } |